Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
d818db9b3b
11 changed files with 307 additions and 19 deletions
  1. +7
    -0
      docs/source/fastNLP.core.callbacks.fitlog_callback.rst
  2. +1
    -0
      docs/source/fastNLP.core.callbacks.rst
  3. +15
    -0
      docs/source/fastNLP.modules.mix_modules.rst
  4. +7
    -0
      docs/source/fastNLP.modules.mix_modules.utils.rst
  5. +15
    -0
      docs/source/fastNLP.modules.rst
  6. +1
    -0
      docs/source/fastNLP.rst
  7. +6
    -6
      fastNLP/core/utils/paddle_utils.py
  8. +1
    -1
      fastNLP/core/utils/rich_progress.py
  9. +2
    -2
      fastNLP/core/utils/torch_utils.py
  10. +10
    -10
      fastNLP/core/utils/utils.py
  11. +242
    -0
      fastNLP/modules/mix_modules/utils.py

+ 7
- 0
docs/source/fastNLP.core.callbacks.fitlog_callback.rst View File

@@ -0,0 +1,7 @@
fastNLP.core.callbacks.fitlog\_callback module
==============================================

.. automodule:: fastNLP.core.callbacks.fitlog_callback
:members:
:undoc-members:
:show-inheritance:

+ 1
- 0
docs/source/fastNLP.core.callbacks.rst View File

@@ -25,6 +25,7 @@ Submodules
fastNLP.core.callbacks.callback_manager
fastNLP.core.callbacks.checkpoint_callback
fastNLP.core.callbacks.early_stop_callback
fastNLP.core.callbacks.fitlog_callback
fastNLP.core.callbacks.has_monitor_callback
fastNLP.core.callbacks.load_best_model_callback
fastNLP.core.callbacks.lr_scheduler_callback


+ 15
- 0
docs/source/fastNLP.modules.mix_modules.rst View File

@@ -0,0 +1,15 @@
fastNLP.modules.mix\_modules package
====================================

.. automodule:: fastNLP.modules.mix_modules
:members:
:undoc-members:
:show-inheritance:

Submodules
----------

.. toctree::
:maxdepth: 4

fastNLP.modules.mix_modules.utils

+ 7
- 0
docs/source/fastNLP.modules.mix_modules.utils.rst View File

@@ -0,0 +1,7 @@
fastNLP.modules.mix\_modules.utils module
=========================================

.. automodule:: fastNLP.modules.mix_modules.utils
:members:
:undoc-members:
:show-inheritance:

+ 15
- 0
docs/source/fastNLP.modules.rst View File

@@ -0,0 +1,15 @@
fastNLP.modules package
=======================

.. automodule:: fastNLP.modules
:members:
:undoc-members:
:show-inheritance:

Subpackages
-----------

.. toctree::
:maxdepth: 4

fastNLP.modules.mix_modules

+ 1
- 0
docs/source/fastNLP.rst View File

@@ -15,3 +15,4 @@ Subpackages
fastNLP.core
fastNLP.envs
fastNLP.io
fastNLP.modules

+ 6
- 6
fastNLP/core/utils/paddle_utils.py View File

@@ -22,9 +22,9 @@ from .utils import apply_to_collection

def _convert_data_device(device: Union[str, int]) -> str:
"""
用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 ``fastNLP`` 会将
用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 **fastNLP** 会将
可见的设备保存在 ``USER_CUDA_VISIBLE_DEVICES`` 中,并且将 ``CUDA_VISIBLE_DEVICES`` 设置为可见的第一张显卡;这是为
了顺利执行 ``paddle`` 的分布式训练而设置的。
了顺利执行 **paddle** 的分布式训练而设置的。
在这种情况下,单纯使用 ``driver.data_device`` 是无效的。比如在分布式训练中将设备设置为 ``[0,2,3]`` ,且用户设置了
``CUDA_VISIBLE_DEVICES=3,4,5,6`` ,那么在 ``rank1``的进程中有::
@@ -127,7 +127,7 @@ def get_paddle_device_id(device: Union[str, int]) -> int:

def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> Any:
r"""
``paddle`` 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。
**paddle** 的数据集合传输到给定设备。只有 :class:`paddle.Tensor` 对象会被传输到设备中,其余保持不变。

:param batch: 需要进行迁移的数据集合;
:param device: 目标设备。可以是显卡设备的编号,或是``cpu``, ``gpu`` 或 ``gpu:x`` 格式的字符串;当这个参数
@@ -145,20 +145,20 @@ def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) ->

def is_in_paddle_dist() -> bool:
"""
判断是否处于 ``paddle`` 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。
判断是否处于 **paddle** 分布式的进程下,使用 ``PADDLE_RANK_IN_NODE`` 和 ``FLAGS_selected_gpus`` 判断。
"""
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)


def is_in_fnlp_paddle_dist() -> bool:
"""
判断是否处于 ``fastNLP`` 拉起的 ``paddle`` 分布式进程中
判断是否处于 **fastNLP** 拉起的 **paddle** 分布式进程中
"""
return FASTNLP_DISTRIBUTED_CHECK in os.environ


def is_in_paddle_launch_dist() -> bool:
"""
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 ``paddle`` 分布式进程中
判断是否处于 ``python -m paddle.distributed.launch`` 方法启动的 **paddle** 分布式进程中
"""
return FASTNLP_BACKEND_LAUNCH in os.environ

+ 1
- 1
fastNLP/core/utils/rich_progress.py View File

@@ -1,5 +1,5 @@
"""
该文件用于为 ``fastNLP`` 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中
该文件用于为 **fastNLP** 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中
的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突
"""
import sys


+ 2
- 2
fastNLP/core/utils/torch_utils.py View File

@@ -44,11 +44,11 @@ class TorchTransferableDataType(ABC):
def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
non_blocking: Optional[bool] = True) -> Any:
r"""
``pytorch`` 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变;
**pytorch** 中将数据集合 ``batch`` 传输到给定设备。任何定义方法 ``to(device)`` 的对象都将被移动并且集合中的所有其他对象将保持不变;

:param batch: 需要迁移的数据;
:param device: 数据应当迁移到的设备;当该参数的值为 ``None`` 时则不执行任何操作;
:param non_blocking: ``pytorch`` 的数据迁移方法 ``to`` 的参数;
:param non_blocking: **pytorch** 的数据迁移方法 ``to`` 的参数;
:return: 迁移到新设备上的数据集合;
"""
if device is None:


+ 10
- 10
fastNLP/core/utils/utils.py View File

@@ -55,7 +55,7 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
r"""
该函数会根据输入函数的形参名从 ``*args`` (均为 ``dict`` 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过
该函数会根据输入函数的形参名从 ``*args`` (均为 **dict** 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过
``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为
``value`` 的参数。

@@ -259,21 +259,21 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:

def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
r"""
用来实现将输入的 ``batch`` 或者输出的 ``outputs`` 通过 ``mapping`` 将键值进行更换的功能;
用来实现将输入的 **batch** 或者输出的 **outputs** 通过 ``mapping`` 将键值进行更换的功能;
该函数应用于 ``input_mapping`` 和 ``output_mapping``;

* 对于 ``input_mapping``,该函数会在 :class:`~fastNLP.core.controllers.TrainBatchLoop` 中取完数据后立刻被调用;
* 对于 ``output_mapping``,该函数会在 :class:`~fastNLP.core.Trainer` 的 :meth:`~fastNLP.core.Trainer.train_step`
以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用;
以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用;

转换的逻辑按优先级依次为:

1. 如果 ``mapping`` 是一个函数,那么会直接返回 ``mapping(data)``
2. 如果 ``mapping`` 是一个 ``Dict``,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``;
1. 如果 ``mapping`` 是一个函数,那么会直接返回 **mapping(data)**
2. 如果 ``mapping`` 是一个 **Dict**,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``;
* 如果 ``data`` 是 ``Dict``,那么该函数会将 ``data`` 的 ``key`` 替换为 ``mapping[key]``
* 如果 ``data`` 是 ``dataclass``,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 ``Dict``,然后进行转换;
* 如果 ``data`` 是 ``Sequence``,那么该函数会先将其转换成一个对应的字典::
* 如果 ``data`` 是 **Dict**,那么该函数会将 ``data`` 的 ``key`` 替换为 **mapping[key]**
* 如果 ``data`` 是 **dataclass**,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 **Dict**,然后进行转换;
* 如果 ``data`` 是 **Sequence**,那么该函数会先将其转换成一个对应的字典::
{
"_0": list[0],
@@ -281,7 +281,7 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None,
...
}

然后使用 ``mapping`` 对这个 ``Dict`` 进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``\'\_number\'`` 这个形式。
然后使用 ``mapping`` 对这个字典进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``'_number'`` 这个形式。

:param mapping: 用于转换的字典或者函数;当 ``mapping`` 是函数时,返回值必须为字典类型;
:param data: 需要被转换的对象;
@@ -459,7 +459,7 @@ def _is_iterable(value):

def pretty_table_printer(dataset_or_ins) -> PrettyTable:
r"""
用于在 ``fastNLP`` 中展示数据的函数::
用于在 **fastNLP** 中展示数据的函数::

>>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
+-----------+-----------+-----------------+


+ 242
- 0
fastNLP/modules/mix_modules/utils.py View File

@@ -0,0 +1,242 @@
import warnings
from typing import Any, Optional, Union

import numpy as np

from fastNLP.core.utils import paddle_to, apply_to_collection
from fastNLP.core.log import logger
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE

if _NEED_IMPORT_PADDLE:
import paddle

if _NEED_IMPORT_JITTOR:
import jittor

if _NEED_IMPORT_TORCH:
import torch

__all__ = [
"paddle2torch",
"torch2paddle",
"jittor2torch",
"torch2jittor",
]

def _paddle2torch(paddle_tensor: 'paddle.Tensor', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor':
"""
将 :class:`paddle.Tensor` 转换为 :class:`torch.Tensor` ,并且能够保留梯度进行反向传播

:param paddle_tensor: 要转换的 **paddle** 张量;
:param device: 是否将转换后的张量迁移到特定设备上,为 ``None``时,和输入的张量相同;
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
:return: 转换后的 **torch** 张量;
"""
no_gradient = paddle_tensor.stop_gradient if no_gradient is None else no_gradient
paddle_numpy = paddle_tensor.numpy()
if not np.issubdtype(paddle_numpy.dtype, np.inexact):
no_gradient = True

if device is None:
if paddle_tensor.place.is_gpu_place():
# paddlepaddle有两种Place,对应不同的device id获取方式
if hasattr(paddle_tensor.place, "gpu_device_id"):
# paddle.fluid.core_avx.Place
# 在gpu环境下创建张量的话,张量的place是这一类型
device = f"cuda:{paddle_tensor.place.gpu_device_id()}"
else:
# paddle.CUDAPlace
device = f"cuda:{paddle_tensor.place.get_device_id()}"
else:
# TODO: 可能需要支持xpu等设备
device = "cpu"

if not no_gradient:
# 保持梯度,并保持反向传播
# torch.tensor会保留numpy数组的类型
torch_tensor = torch.tensor(paddle_numpy, requires_grad=True, device=device)
hook = torch_tensor.register_hook(
lambda grad: paddle.autograd.backward(paddle_tensor, paddle.to_tensor(grad.cpu().numpy()))
)
else:
# 不保留梯度
torch_tensor = torch.tensor(paddle_numpy, requires_grad=False, device=device)

return torch_tensor


def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient: bool = None) -> 'paddle.Tensor':
"""
将 :class:`torch.Tensor` 转换为 :class:`paddle.Tensor`,并且能够保留梯度进行反向传播。

:param torch_tensor: 要转换的 **torch** 张量;
:param device: 是否将转换后的张量迁移到特定设备上,输入为 ``None`` 时,和输入的张量相同;
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
:return: 转换后的 **paddle** 张量;
"""
no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient
if device is None:
if torch_tensor.is_cuda:
device = f"gpu:{torch_tensor.device.index}"
else:
device = "cpu"

if not no_gradient:
# 保持梯度并保持反向传播
# paddle的stop_gradient和torch的requires_grad表现是相反的
paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
hook = paddle_tensor.register_hook(
lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
)
else:
paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)

paddle_tensor = paddle_to(paddle_tensor, device)

return paddle_tensor


def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor':
"""
将 :class:`jittor.Var` 转换为 :class:`torch.Tensor` 。

:param jittor_var: 要转换的 **jittor** 变量;
:param device: 是否将转换后的张量迁移到特定设备上,输入为 ``None`` 时,根据 ``jittor.flags.use_cuda`` 决定;
:param no_gradient: 是否保留原张量的梯度。为``None``时,新的张量与输入张量保持一致;
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
:return: 转换后的 **torch** 张量;
"""
# TODO: warning:无法保留梯度
# jittor的grad可以通过callback进行传递
# 如果outputs有_grad键,可以实现求导
no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
if no_gradient == False:
warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
jittor_numpy = jittor_var.numpy()
if not np.issubdtype(jittor_numpy.dtype, np.inexact):
no_gradient = True

if device is None:
# jittor的设备分配是自动的
# 根据use_cuda判断
if jittor.flags.use_cuda:
device = "cuda:0"
else:
device = "cpu"

torch_tensor = torch.tensor(jittor_numpy, requires_grad=not no_gradient, device=device)

return torch_tensor


def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'jittor.Var':
"""
将 :class:`torch.Tensor` 转换为 :class:`jittor.Var` 。

:param torch_tensor: 要转换的 **torch** 张量;
:param no_gradient: 是否保留原张量的梯度。为``None``时,新的张量与输入张量保持一致;
为 ``True`` 时,全部不保留梯度;为 ``False`` 时,全部保留梯度;
:return: 转换后的 **jittor** 变量;
"""
no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient

if not no_gradient:
# 保持梯度并保持反向传播
jittor_var = jittor.Var(torch_tensor.detach().numpy())
jittor_var.requires_grad = True
hook = jittor_var.register_hook(
lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
)
else:
jittor_var = jittor.Var(torch_tensor.detach().numpy())
jittor_var.requires_grad = False

return jittor_var


def torch2paddle(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
"""
递归地将输入中包含的 :class:`torch.Tensor` 转换为 :class:`paddle.Tensor` 。

:param batch: 包含 :class:`torch.Tensor` 类型的数据集合
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None`` 时,和输入保持一致;
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度;
:return: 转换后的数据;
"""

return apply_to_collection(
batch,
dtype=torch.Tensor,
function=_torch2paddle,
device=device,
no_gradient=no_gradient,
)


def paddle2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
"""
递归地将输入中包含的 :class:`paddle.Tensor` 转换为 :class:`torch.Tensor` 。

:param batch: 包含 :class:`paddle.Tensor` 类型的数据集合;
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致;
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度;
:return: 转换后的数据;
"""

return apply_to_collection(
batch,
dtype=paddle.Tensor,
function=_paddle2torch,
device=device,
no_gradient=no_gradient,
)


def jittor2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
"""
递归地将输入中包含的 :class:`jittor.Var` 转换为 :class:`torch.Tensor` 。

.. note::

注意,由于 **pytorch** 和 **jittor** 之间的差异,从 :class:`jittor.Var` 转换
至 :class:`torch.Tensor` 的过程中无法保留原张量的梯度。

:param batch: 包含 :class:`jittor.Var` 类型的数据集合;
:param device: 是否将转换后的张量迁移到特定设备上。为 ``None``时,和输入保持一致;
:param no_gradient: 是否保留原张量的梯度,在这个函数中该参数无效。
:return: 转换后的数据;
"""

return apply_to_collection(
batch,
dtype=jittor.Var,
function=_jittor2torch,
device=device,
no_gradient=no_gradient,
)


def torch2jittor(batch: Any, no_gradient: bool = None) -> Any:
"""
递归地将输入中包含的 :class:`torch.Tensor` 转换为 :class:`jittor.Var` 。

.. note::

**jittor** 会自动为创建的变量分配设备。

:param batch: 包含 :class:`torch.Tensor` 类型的数据集合;
:param no_gradient: 是否保留原张量的梯度。为 ``None`` 时,新的张量与输入张量保持一致;
为 ``True`` 时,不保留梯度;为 ``False`` 时,保留梯度;
:return: 转换后的数据;
"""
return apply_to_collection(
batch,
dtype=torch.Tensor,
function=_torch2jittor,
no_gradient=no_gradient,
)

Loading…
Cancel
Save