|
|
@@ -10,10 +10,6 @@ from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence |
|
|
|
from typing import Tuple, Optional |
|
|
|
from time import sleep |
|
|
|
|
|
|
|
try: |
|
|
|
from typing import Literal, Final |
|
|
|
except ImportError: |
|
|
|
from typing_extensions import Literal, Final |
|
|
|
import os |
|
|
|
from contextlib import contextmanager |
|
|
|
from functools import wraps |
|
|
@@ -22,7 +18,6 @@ import numpy as np |
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
from fastNLP.core.log import logger |
|
|
|
from ...envs import SUPPORT_BACKENDS |
|
|
|
|
|
|
|
|
|
|
|
__all__ = [ |
|
|
@@ -43,10 +38,10 @@ __all__ = [ |
|
|
|
|
|
|
|
def get_fn_arg_names(fn: Callable) -> List[str]: |
|
|
|
r""" |
|
|
|
返回一个函数的所有参数的名字; |
|
|
|
返回一个函数所有参数的名字 |
|
|
|
|
|
|
|
:param fn: 需要查询的函数; |
|
|
|
:return: 一个列表,其中的元素则是查询函数的参数的字符串名字; |
|
|
|
:param fn: 需要查询的函数 |
|
|
|
:return: 一个列表,其中的元素是函数 ``fn`` 参数的字符串名字 |
|
|
|
""" |
|
|
|
return list(inspect.signature(fn).parameters) |
|
|
|
|
|
|
@@ -54,24 +49,18 @@ def get_fn_arg_names(fn: Callable) -> List[str]: |
|
|
|
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, |
|
|
|
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: |
|
|
|
r""" |
|
|
|
该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping |
|
|
|
参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 |
|
|
|
该函数会根据输入函数的形参名从 ``*args`` (因此都需要是 ``dict`` 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过 |
|
|
|
``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为 |
|
|
|
``value`` 的参数。 |
|
|
|
|
|
|
|
1.该函数用来提供给用户根据字符串匹配从而实现自动调用; |
|
|
|
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; |
|
|
|
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; |
|
|
|
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; |
|
|
|
4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同; |
|
|
|
|
|
|
|
:param fn: 用来进行实际计算的函数,其参数可以包含有默认值; |
|
|
|
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数; |
|
|
|
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 |
|
|
|
参数值后,再传给 `fn` 进行实际的运算; |
|
|
|
:param mapping: 一个字典,用来更改其前面的字典的键值; |
|
|
|
|
|
|
|
:return: 返回 `fn` 运行的结果; |
|
|
|
1. 该函数用来提供给用户根据字符串匹配从而实现自动调用; |
|
|
|
2. 注意 ``mapping`` 默认为 ``None``,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 ``mapping`` 为一个字典传入进来; |
|
|
|
如果 ``mapping`` 不为 ``None``,那么我们一定会先使用 ``mapping`` 将输入的字典的 ``keys`` 修改过来,因此请务必亲自检查 ``mapping`` 的正确性; |
|
|
|
3. 如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; |
|
|
|
4. 如果输入的函数是一个 ``partial`` 函数,情况同第三点,即和默认参数的情况相同; |
|
|
|
|
|
|
|
Examples:: |
|
|
|
|
|
|
|
>>> # 1 |
|
|
|
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred); |
|
|
|
>>> batch = {"x": 20, "y": 1} |
|
|
@@ -84,6 +73,14 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None |
|
|
|
>>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70 |
|
|
|
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 |
|
|
|
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 |
|
|
|
|
|
|
|
:param fn: 用来进行实际计算的函数,其参数可以包含有默认值; |
|
|
|
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 ``fn`` 计算所需要的实际参数; |
|
|
|
:param signature_fn: 函数,用来替换 ``fn`` 的函数签名,如果该参数不为 ``None``,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 |
|
|
|
参数值后,再传给 ``fn`` 进行实际的运算; |
|
|
|
:param mapping: 一个字典,用来更改其前面的字典的键值; |
|
|
|
|
|
|
|
:return: 返回 ``fn`` 运行的结果; |
|
|
|
""" |
|
|
|
|
|
|
|
if signature_fn is not None: |
|
|
@@ -226,13 +223,13 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): |
|
|
|
|
|
|
|
def check_user_specific_params(user_params: Dict, fn: Callable): |
|
|
|
""" |
|
|
|
该函数使用用户的输入来对指定函数的参数进行赋值; |
|
|
|
主要用于一些用户无法直接调用函数的情况; |
|
|
|
该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误; |
|
|
|
该函数使用用户的输入来对指定函数的参数进行赋值,主要用于一些用户无法直接调用函数的情况; |
|
|
|
该函数主要的作用在于帮助检查用户对使用函数 ``fn`` 的参数输入是否有误; |
|
|
|
|
|
|
|
:param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值; |
|
|
|
:param fn: 会被调用的函数; |
|
|
|
:return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值; |
|
|
|
:param user_params: 用户指定的参数的值,应当是一个字典,其中 ``key`` 表示每一个参数的名字, |
|
|
|
``value`` 为每一个参数的值; |
|
|
|
:param fn: 将要被调用的函数; |
|
|
|
:return: 返回一个字典,其中为在之后调用函数 ``fn`` 时真正会被传进去的参数的值; |
|
|
|
""" |
|
|
|
|
|
|
|
fn_arg_names = get_fn_arg_names(fn) |
|
|
@@ -243,6 +240,9 @@ def check_user_specific_params(user_params: Dict, fn: Callable): |
|
|
|
|
|
|
|
|
|
|
|
def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: |
|
|
|
""" |
|
|
|
将传入的 `dataclass` 实例转换为字典。 |
|
|
|
""" |
|
|
|
if not is_dataclass(data): |
|
|
|
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") |
|
|
|
_dict = dict() |
|
|
@@ -253,21 +253,31 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: |
|
|
|
|
|
|
|
def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any: |
|
|
|
r""" |
|
|
|
用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能; |
|
|
|
该函数应用于 `input_mapping` 和 `output_mapping`; |
|
|
|
对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用; |
|
|
|
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; |
|
|
|
用来实现将输入的 ``batch``,或者输出的 ``outputs``,通过 ``mapping`` 将键值进行更换的功能; |
|
|
|
该函数应用于 ``input_mapping`` 和 ``output_mapping``; |
|
|
|
|
|
|
|
转换的逻辑按优先级依次为: |
|
|
|
对于 ``input_mapping``,该函数会在 :class:`~fastNLP.core.controllers.TrainBatchLoop` 中取完数据后立刻被调用; |
|
|
|
对于 ``output_mapping``,该函数会在 :class:`~fastNLP.core.Trainer` 的 :meth:`~fastNLP.core.Trainer.train_step` |
|
|
|
以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用; |
|
|
|
|
|
|
|
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; |
|
|
|
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; |
|
|
|
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; |
|
|
|
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; |
|
|
|
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 |
|
|
|
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 |
|
|
|
转换的逻辑按优先级依次为: |
|
|
|
|
|
|
|
:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 |
|
|
|
1. 如果 ``mapping`` 是一个函数,那么会直接返回 ``mapping(data)``; |
|
|
|
2. 如果 ``mapping`` 是一个 ``Dict``,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``; |
|
|
|
|
|
|
|
* 如果 ``data`` 是 ``Dict``,那么该函数会将 ``data`` 的 ``key`` 替换为 ``mapping[key]``; |
|
|
|
* 如果 ``data`` 是 ``dataclass``,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 ``Dict``,然后进行转换; |
|
|
|
* 如果 ``data`` 是 ``Sequence``,那么该函数会先将其转换成一个对应的字典:: |
|
|
|
|
|
|
|
{ |
|
|
|
"_0": list[0], |
|
|
|
"_1": list[1], |
|
|
|
... |
|
|
|
} |
|
|
|
|
|
|
|
然后使用 ``mapping`` 对这个 ``Dict`` 进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``\'\_number\'`` 这个形式。 |
|
|
|
|
|
|
|
:param mapping: 用于转换的字典或者函数;``mapping`` 是函数时,返回值必须为字典类型。 |
|
|
|
:param data: 需要被转换的对象; |
|
|
|
:return: 返回转换好的结果; |
|
|
|
""" |
|
|
@@ -320,21 +330,20 @@ def apply_to_collection( |
|
|
|
include_none: bool = True, |
|
|
|
**kwargs: Any, |
|
|
|
) -> Any: |
|
|
|
"""将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。 |
|
|
|
|
|
|
|
this function credit to: https://github.com/PyTorchLightning/pytorch-lightning |
|
|
|
Args: |
|
|
|
data: the collection to apply the function to |
|
|
|
dtype: the given function will be applied to all elements of this dtype |
|
|
|
function: the function to apply |
|
|
|
*args: positional arguments (will be forwarded to calls of ``function``) |
|
|
|
wrong_dtype: the given function won't be applied if this type is specified and the given collections |
|
|
|
is of the ``wrong_dtype`` even if it is of type ``dtype`` |
|
|
|
include_none: Whether to include an element if the output of ``function`` is ``None``. |
|
|
|
**kwargs: keyword arguments (will be forwarded to calls of ``function``) |
|
|
|
|
|
|
|
Returns: |
|
|
|
The resulting collection |
|
|
|
""" |
|
|
|
使用函数 ``function`` 递归地在 ``data`` 中的元素执行,但是仅在满足元素为 ``dtype`` 时执行。 |
|
|
|
|
|
|
|
该函数参考了 `pytorch-lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ 的实现 |
|
|
|
|
|
|
|
:param data: 需要进行处理的数据集合或数据 |
|
|
|
:param dtype: 数据的类型,函数 ``function`` 只会被应用于 ``data`` 中类型为 ``dtype`` 的数据 |
|
|
|
:param function: 对数据进行处理的函数 |
|
|
|
:param args: ``function`` 所需要的其它参数 |
|
|
|
:param wrong_dtype: ``function`` 一定不会生效的数据类型。如果数据既是 ``wrong_dtype`` 类型又是 ``dtype`` 类型 |
|
|
|
那么也不会生效。 |
|
|
|
:param include_none: 是否包含执行结果为 ``None`` 的数据,默认为 ``True``。 |
|
|
|
:param kwargs: ``function`` 所需要的其它参数 |
|
|
|
:return: 经过 ``function`` 处理后的数据集合 |
|
|
|
""" |
|
|
|
# Breaking condition |
|
|
|
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): |
|
|
@@ -402,16 +411,18 @@ def apply_to_collection( |
|
|
|
@contextmanager |
|
|
|
def nullcontext(): |
|
|
|
r""" |
|
|
|
用来实现一个什么 dummy 的 context 上下文环境; |
|
|
|
实现一个什么都不做的上下文环境 |
|
|
|
""" |
|
|
|
yield |
|
|
|
|
|
|
|
|
|
|
|
def sub_column(string: str, c: int, c_size: int, title: str) -> str: |
|
|
|
r""" |
|
|
|
对传入的字符串进行截断,方便在命令行中显示 |
|
|
|
|
|
|
|
:param string: 要被截断的字符串 |
|
|
|
:param c: 命令行列数 |
|
|
|
:param c_size: instance或dataset field数 |
|
|
|
:param c_size: :class:`~fastNLP.core.Instance` 或 :class:`fastNLP.core.DataSet` 的 ``field`` 数目 |
|
|
|
:param title: 列名 |
|
|
|
:return: 对一个过长的列进行截断的结果 |
|
|
|
""" |
|
|
@@ -442,18 +453,17 @@ def _is_iterable(value): |
|
|
|
|
|
|
|
def pretty_table_printer(dataset_or_ins) -> PrettyTable: |
|
|
|
r""" |
|
|
|
:param dataset_or_ins: 传入一个dataSet或者instance |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
在 ``fastNLP`` 中展示数据的函数:: |
|
|
|
|
|
|
|
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) |
|
|
|
>>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) |
|
|
|
+-----------+-----------+-----------------+ |
|
|
|
| field_1 | field_2 | field_3 | |
|
|
|
+-----------+-----------+-----------------+ |
|
|
|
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | |
|
|
|
+-----------+-----------+-----------------+ |
|
|
|
|
|
|
|
:return: 以 pretty table的形式返回根据terminal大小进行自动截断 |
|
|
|
:param dataset_or_ins: 要展示的 :class:`~fastNLP.core.DataSet` 或者 :class:`~fastNLP.core.Instance` |
|
|
|
:return: 根据 ``terminal`` 大小进行自动截断的数据表格 |
|
|
|
""" |
|
|
|
x = PrettyTable() |
|
|
|
try: |
|
|
@@ -486,7 +496,7 @@ def pretty_table_printer(dataset_or_ins) -> PrettyTable: |
|
|
|
|
|
|
|
|
|
|
|
class Option(dict): |
|
|
|
r"""a dict can treat keys as attributes""" |
|
|
|
r"""将键转化为属性的字典类型""" |
|
|
|
|
|
|
|
def __getattr__(self, item): |
|
|
|
try: |
|
|
@@ -516,11 +526,10 @@ _emitted_deprecation_warnings = set() |
|
|
|
|
|
|
|
|
|
|
|
def deprecated(help_message: Optional[str] = None): |
|
|
|
"""Decorator to mark a function as deprecated. |
|
|
|
""" |
|
|
|
标记当前功能已经过时的装饰器。 |
|
|
|
|
|
|
|
Args: |
|
|
|
help_message (`Optional[str]`): An optional message to guide the user on how to |
|
|
|
switch to non-deprecated usage of the library. |
|
|
|
:param help_message: 一段指引信息,告知用户如何将代码切换为当前版本提倡的用法。 |
|
|
|
""" |
|
|
|
|
|
|
|
def decorator(deprecated_function: Callable): |
|
|
@@ -549,11 +558,10 @@ def deprecated(help_message: Optional[str] = None): |
|
|
|
return decorator |
|
|
|
|
|
|
|
|
|
|
|
def seq_len_to_mask(seq_len, max_len=None): |
|
|
|
def seq_len_to_mask(seq_len, max_len: Optional[int]): |
|
|
|
r""" |
|
|
|
|
|
|
|
将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 |
|
|
|
转变 1-d seq_len到2-d mask. |
|
|
|
将一个表示 ``sequence length`` 的一维数组转换为二维的 ``mask`` ,不包含的位置为 **0**。 |
|
|
|
|
|
|
|
.. code-block:: |
|
|
|
|
|
|
@@ -570,10 +578,11 @@ def seq_len_to_mask(seq_len, max_len=None): |
|
|
|
>>>print(mask.size()) |
|
|
|
torch.Size([14, 100]) |
|
|
|
|
|
|
|
:param np.ndarray,torch.LongTensor seq_len: shape将是(B,) |
|
|
|
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 |
|
|
|
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 |
|
|
|
:return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8 |
|
|
|
:param seq_len: 大小为是 ``(B,)`` 的长度序列 |
|
|
|
:param int max_len: 将长度 ``pad`` 到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度。 |
|
|
|
但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入 |
|
|
|
一个 ``max_len`` 使得 ``mask`` 的长度 ``pad`` 到该长度。 |
|
|
|
:return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8`` |
|
|
|
""" |
|
|
|
if isinstance(seq_len, np.ndarray): |
|
|
|
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." |
|
|
|