Browse Source

完善文档;更改文档设置,现在可以展示参数的默认值了

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
49cf176e74
42 changed files with 167 additions and 142 deletions
  1. +4
    -1
      docs/Makefile
  2. +7
    -4
      docs/source/conf.py
  3. +2
    -2
      docs/source/fastNLP.core.callbacks.rst
  4. +1
    -1
      docs/source/fastNLP.core.callbacks.torch_callbacks.rst
  5. +1
    -1
      docs/source/fastNLP.core.collators.padders.rst
  6. +2
    -2
      docs/source/fastNLP.core.collators.rst
  7. +1
    -1
      docs/source/fastNLP.core.controllers.loops.rst
  8. +2
    -2
      docs/source/fastNLP.core.controllers.rst
  9. +1
    -1
      docs/source/fastNLP.core.controllers.utils.rst
  10. +1
    -1
      docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst
  11. +1
    -1
      docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst
  12. +7
    -0
      docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst
  13. +3
    -2
      docs/source/fastNLP.core.dataloaders.rst
  14. +1
    -1
      docs/source/fastNLP.core.dataloaders.torch_dataloader.rst
  15. +1
    -1
      docs/source/fastNLP.core.dataset.rst
  16. +1
    -1
      docs/source/fastNLP.core.drivers.jittor_driver.rst
  17. +1
    -1
      docs/source/fastNLP.core.drivers.paddle_driver.rst
  18. +2
    -2
      docs/source/fastNLP.core.drivers.rst
  19. +1
    -1
      docs/source/fastNLP.core.drivers.torch_driver.rst
  20. +1
    -1
      docs/source/fastNLP.core.log.rst
  21. +1
    -1
      docs/source/fastNLP.core.metrics.backend.jittor_backend.rst
  22. +1
    -1
      docs/source/fastNLP.core.metrics.backend.paddle_backend.rst
  23. +2
    -2
      docs/source/fastNLP.core.metrics.backend.rst
  24. +1
    -1
      docs/source/fastNLP.core.metrics.backend.torch_backend.rst
  25. +2
    -2
      docs/source/fastNLP.core.metrics.rst
  26. +2
    -2
      docs/source/fastNLP.core.rst
  27. +1
    -1
      docs/source/fastNLP.core.samplers.rst
  28. +1
    -1
      docs/source/fastNLP.core.utils.rst
  29. +1
    -1
      docs/source/fastNLP.envs.rst
  30. +1
    -1
      docs/source/fastNLP.io.loader.rst
  31. +1
    -1
      docs/source/fastNLP.io.pipe.rst
  32. +2
    -2
      docs/source/fastNLP.io.rst
  33. +1
    -1
      docs/source/fastNLP.rst
  34. +1
    -1
      docs/source/modules.rst
  35. +2
    -2
      fastNLP/core/callbacks/utils.py
  36. +2
    -2
      fastNLP/core/dataloaders/paddle_dataloader/fdl.py
  37. +2
    -2
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  38. +10
    -9
      fastNLP/core/dataset/dataset.py
  39. +1
    -1
      fastNLP/core/utils/dummy_class.py
  40. +2
    -3
      fastNLP/core/utils/rich_progress.py
  41. +84
    -75
      fastNLP/core/utils/utils.py
  42. +5
    -3
      fastNLP/envs/utils.py

+ 4
- 1
docs/Makefile View File

@@ -16,7 +16,7 @@ help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS)

apidoc:
$(SPHINXAPIDOC) -efM -d 6 -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE)
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE)

server:
cd build/html && python -m http.server $(PORT)
@@ -24,6 +24,9 @@ server:
delete:
rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build

web:
make html && make server

dev:
make delete && make apidoc && make html && make server



+ 7
- 4
docs/source/conf.py View File

@@ -42,7 +42,8 @@ extensions = [
'sphinx.ext.viewcode',
'sphinx.ext.autosummary',
'sphinx.ext.mathjax',
'sphinx.ext.todo'
'sphinx.ext.todo',
'sphinx_autodoc_typehints'
]

autodoc_default_options = {
@@ -53,8 +54,10 @@ autodoc_default_options = {

add_module_names = False
autosummary_ignore_module_all = False
autodoc_typehints = "description"
# autodoc_typehints = "description"
autoclass_content = "class"
typehints_fully_qualified = False
typehints_defaults = "comma"

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
@@ -168,8 +171,8 @@ texinfo_documents = [

# -- Extension configuration -------------------------------------------------
def maybe_skip_member(app, what, name, obj, skip, options):
# if obj.__doc__ is None:
# return True
if obj.__doc__ is None:
return True
if name == "__init__":
return False
if name.startswith("_"):


+ 2
- 2
docs/source/fastNLP.core.callbacks.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.callbacks.torch_callbacks

@@ -18,7 +18,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.callbacks.callback
fastNLP.core.callbacks.callback_event


+ 1
- 1
docs/source/fastNLP.core.callbacks.torch_callbacks.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback
fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback

+ 1
- 1
docs/source/fastNLP.core.collators.padders.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.collators.padders.exceptions
fastNLP.core.collators.padders.get_padder


+ 2
- 2
docs/source/fastNLP.core.collators.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.collators.padders

@@ -18,7 +18,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.collators.collator
fastNLP.core.collators.packer_unpacker

+ 1
- 1
docs/source/fastNLP.core.controllers.loops.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.controllers.loops.evaluate_batch_loop
fastNLP.core.controllers.loops.loop


+ 2
- 2
docs/source/fastNLP.core.controllers.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.controllers.loops
fastNLP.core.controllers.utils
@@ -19,7 +19,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.controllers.evaluator
fastNLP.core.controllers.trainer

+ 1
- 1
docs/source/fastNLP.core.controllers.utils.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.controllers.utils.state
fastNLP.core.controllers.utils.utils

+ 1
- 1
docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst View File

@@ -10,6 +10,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.dataloaders.jittor_dataloader.fdl

+ 1
- 1
docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst View File

@@ -10,6 +10,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.dataloaders.paddle_dataloader.fdl

+ 7
- 0
docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst View File

@@ -0,0 +1,7 @@
fastNLP.core.dataloaders.prepare\_dataloader module
===================================================

.. automodule:: fastNLP.core.dataloaders.prepare_dataloader
:members:
:undoc-members:
:show-inheritance:

+ 3
- 2
docs/source/fastNLP.core.dataloaders.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.dataloaders.jittor_dataloader
fastNLP.core.dataloaders.paddle_dataloader
@@ -20,7 +20,8 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.dataloaders.mix_dataloader
fastNLP.core.dataloaders.prepare_dataloader
fastNLP.core.dataloaders.utils

+ 1
- 1
docs/source/fastNLP.core.dataloaders.torch_dataloader.rst View File

@@ -10,6 +10,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.dataloaders.torch_dataloader.fdl

+ 1
- 1
docs/source/fastNLP.core.dataset.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.dataset.dataset
fastNLP.core.dataset.field


+ 1
- 1
docs/source/fastNLP.core.drivers.jittor_driver.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.drivers.jittor_driver.initialize_jittor_driver
fastNLP.core.drivers.jittor_driver.jittor_driver


+ 1
- 1
docs/source/fastNLP.core.drivers.paddle_driver.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.drivers.paddle_driver.dist_utils
fastNLP.core.drivers.paddle_driver.fleet


+ 2
- 2
docs/source/fastNLP.core.drivers.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.drivers.jittor_driver
fastNLP.core.drivers.paddle_driver
@@ -20,7 +20,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.drivers.choose_driver
fastNLP.core.drivers.driver


+ 1
- 1
docs/source/fastNLP.core.drivers.torch_driver.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.drivers.torch_driver.ddp
fastNLP.core.drivers.torch_driver.dist_utils


+ 1
- 1
docs/source/fastNLP.core.log.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.log.handler
fastNLP.core.log.highlighter


+ 1
- 1
docs/source/fastNLP.core.metrics.backend.jittor_backend.rst View File

@@ -10,6 +10,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.backend.jittor_backend.backend

+ 1
- 1
docs/source/fastNLP.core.metrics.backend.paddle_backend.rst View File

@@ -10,6 +10,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.backend.paddle_backend.backend

+ 2
- 2
docs/source/fastNLP.core.metrics.backend.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.backend.jittor_backend
fastNLP.core.metrics.backend.paddle_backend
@@ -20,7 +20,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.backend.auto_backend
fastNLP.core.metrics.backend.backend

+ 1
- 1
docs/source/fastNLP.core.metrics.backend.torch_backend.rst View File

@@ -10,6 +10,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.backend.torch_backend.backend

+ 2
- 2
docs/source/fastNLP.core.metrics.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.backend

@@ -18,7 +18,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.metrics.accuracy
fastNLP.core.metrics.classify_f1_pre_rec_metric


+ 2
- 2
docs/source/fastNLP.core.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.callbacks
fastNLP.core.collators
@@ -27,6 +27,6 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.vocabulary

+ 1
- 1
docs/source/fastNLP.core.samplers.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.samplers.conversion_utils
fastNLP.core.samplers.mix_sampler


+ 1
- 1
docs/source/fastNLP.core.utils.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core.utils.cache_results
fastNLP.core.utils.dummy_class


+ 1
- 1
docs/source/fastNLP.envs.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.envs.distributed
fastNLP.envs.env


+ 1
- 1
docs/source/fastNLP.io.loader.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.io.loader.classification
fastNLP.io.loader.conll


+ 1
- 1
docs/source/fastNLP.io.pipe.rst View File

@@ -10,7 +10,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.io.pipe.classification
fastNLP.io.pipe.conll


+ 2
- 2
docs/source/fastNLP.io.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.io.loader
fastNLP.io.pipe
@@ -19,7 +19,7 @@ Submodules
----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.io.data_bundle
fastNLP.io.embed_loader


+ 1
- 1
docs/source/fastNLP.rst View File

@@ -10,7 +10,7 @@ Subpackages
-----------

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP.core
fastNLP.envs


+ 1
- 1
docs/source/modules.rst View File

@@ -2,6 +2,6 @@ fastNLP
=======

.. toctree::
:maxdepth: 6
:maxdepth: 4

fastNLP

+ 2
- 2
fastNLP/core/callbacks/utils.py View File

@@ -8,8 +8,8 @@ from fastNLP.core.utils.utils import _get_fun_msg

def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->Tuple[str, float]:
"""
res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行
匹配。
``res`` 中寻找 ``monitor`` 并返回。如果 ``monitor`` 没找到则尝试用 ``_real_monitor`` ,若 ``_real_monitor`` 为 ``None``
则尝试使用 ``monitor`` 的值进行匹配。

:param monitor:
:param real_monitor:


+ 2
- 2
fastNLP/core/dataloaders/paddle_dataloader/fdl.py View File

@@ -162,9 +162,9 @@ class PaddleDataLoader(DataLoader):

def get_batch_indices(self) -> List[int]:
"""
获取当前 batch 的 idx
获取当前 ``batch`` 中每条数据对应的索引。

:return:
:return: 当前 ``batch`` 数据的索引
"""
return self.cur_batch_indices



+ 2
- 2
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -170,9 +170,9 @@ class TorchDataLoader(DataLoader):

def get_batch_indices(self) -> List[int]:
"""
获取当前 batch 的 idx
获取当前 ``batch`` 中每条数据对应的索引。

:return:
:return: 当前 ``batch`` 数据的索引
"""
return self.cur_batch_indices



+ 10
- 9
fastNLP/core/dataset/dataset.py View File

@@ -400,15 +400,16 @@ class DataSet:
new_field_name: str = None, num_proc: int = 0,
progress_desc: str = None, show_progress_bar: bool = True):
r"""
将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。

:param field_name: 传入 func 的是哪个 field。
:param func: input是 instance 中名为 `field_name` 的 field 的内容。
:param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆
盖之前的 field。如果为 None 则不创建新的 field。
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param progress_desc: progress_desc 的值,默认为 Main
:param show_progress_bar: 是否展示进度条,默认展示进度条
将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并获取函数的返回值。

:param field_name: 传入 ``func`` 的 ``field`` 名称。
:param func: 一个函数,其输入是 ``instance`` 中名为 ``field_name`` 的 ``field`` 的内容。
:param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 对应的 ``field`` 中,如果名称与已有的 ``field`` 相同
则进行覆盖。如果为 ``None`` 则不会覆盖和创建 ``field`` 。
:param num_proc: 使用进程的数量。请注意,由于 ``python`` 语言的特性,使用了多少进程就会导致多少倍内存的增长。
:param progress_desc: 进度条的描述字符,默认为 ``Main``。
:param show_progress_bar: 是否展示进度条;默认为展示。
:return: 从函数 ``func`` 中得到的返回值。
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
if not self.has_field(field_name=field_name):


+ 1
- 1
fastNLP/core/utils/dummy_class.py View File

@@ -1,4 +1,4 @@
import functools
__all__ = []

class DummyClass:
def __init__(self, *args, **kwargs):


+ 2
- 3
fastNLP/core/utils/rich_progress.py View File

@@ -1,7 +1,6 @@
"""
该文件用于为fastNLP提供一个统一的progress bar管理,通过共用一个Task对象,trainer中的progress bar和evaluation中的progress bar才能
不冲突

该文件用于为 ``fastNLP`` 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中
的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突
"""
import sys
from typing import Any, Union, Optional


+ 84
- 75
fastNLP/core/utils/utils.py View File

@@ -10,10 +10,6 @@ from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
from typing import Tuple, Optional
from time import sleep

try:
from typing import Literal, Final
except ImportError:
from typing_extensions import Literal, Final
import os
from contextlib import contextmanager
from functools import wraps
@@ -22,7 +18,6 @@ import numpy as np
from pathlib import Path

from fastNLP.core.log import logger
from ...envs import SUPPORT_BACKENDS


__all__ = [
@@ -43,10 +38,10 @@ __all__ = [

def get_fn_arg_names(fn: Callable) -> List[str]:
r"""
返回一个函数所有参数的名字
返回一个函数所有参数的名字

:param fn: 需要查询的函数
:return: 一个列表,其中的元素则是查询函数的参数的字符串名字;
:param fn: 需要查询的函数
:return: 一个列表,其中的元素是函数 ``fn`` 参数的字符串名字
"""
return list(inspect.signature(fn).parameters)

@@ -54,24 +49,18 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
r"""
该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping
参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。
该函数会根据输入函数的形参名从 ``*args`` (因此都需要是 ``dict`` 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过
``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为
``value`` 的参数。

1.该函数用来提供给用户根据字符串匹配从而实现自动调用;
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来;
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性;
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值;
4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同;

:param fn: 用来进行实际计算的函数,其参数可以包含有默认值;
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数;
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取
参数值后,再传给 `fn` 进行实际的运算;
:param mapping: 一个字典,用来更改其前面的字典的键值;

:return: 返回 `fn` 运行的结果;
1. 该函数用来提供给用户根据字符串匹配从而实现自动调用;
2. 注意 ``mapping`` 默认为 ``None``,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 ``mapping`` 为一个字典传入进来;
如果 ``mapping`` 不为 ``None``,那么我们一定会先使用 ``mapping`` 将输入的字典的 ``keys`` 修改过来,因此请务必亲自检查 ``mapping`` 的正确性;
3. 如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值;
4. 如果输入的函数是一个 ``partial`` 函数,情况同第三点,即和默认参数的情况相同;

Examples::

>>> # 1
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred);
>>> batch = {"x": 20, "y": 1}
@@ -84,6 +73,14 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
>>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240

:param fn: 用来进行实际计算的函数,其参数可以包含有默认值;
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 ``fn`` 计算所需要的实际参数;
:param signature_fn: 函数,用来替换 ``fn`` 的函数签名,如果该参数不为 ``None``,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取
参数值后,再传给 ``fn`` 进行实际的运算;
:param mapping: 一个字典,用来更改其前面的字典的键值;

:return: 返回 ``fn`` 运行的结果;
"""

if signature_fn is not None:
@@ -226,13 +223,13 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):

def check_user_specific_params(user_params: Dict, fn: Callable):
"""
该函数使用用户的输入来对指定函数的参数进行赋值;
主要用于一些用户无法直接调用函数的情况;
该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误;
该函数使用用户的输入来对指定函数的参数进行赋值,主要用于一些用户无法直接调用函数的情况;
该函数主要的作用在于帮助检查用户对使用函数 ``fn`` 的参数输入是否有误;

:param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值;
:param fn: 会被调用的函数;
:return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值;
:param user_params: 用户指定的参数的值,应当是一个字典,其中 ``key`` 表示每一个参数的名字,
``value`` 为每一个参数的值;
:param fn: 将要被调用的函数;
:return: 返回一个字典,其中为在之后调用函数 ``fn`` 时真正会被传进去的参数的值;
"""

fn_arg_names = get_fn_arg_names(fn)
@@ -243,6 +240,9 @@ def check_user_specific_params(user_params: Dict, fn: Callable):


def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
"""
将传入的 `dataclass` 实例转换为字典。
"""
if not is_dataclass(data):
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
_dict = dict()
@@ -253,21 +253,31 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:

def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
r"""
用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能;
该函数应用于 `input_mapping` 和 `output_mapping`;
对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用;
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用;
用来实现将输入的 ``batch``,或者输出的 ``outputs``,通过 ``mapping`` 将键值进行更换的功能;
该函数应用于 ``input_mapping`` 和 ``output_mapping``;

转换的逻辑按优先级依次为:
对于 ``input_mapping``,该函数会在 :class:`~fastNLP.core.controllers.TrainBatchLoop` 中取完数据后立刻被调用;
对于 ``output_mapping``,该函数会在 :class:`~fastNLP.core.Trainer` 的 :meth:`~fastNLP.core.Trainer.train_step`
以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用;

1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`;
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`];
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key];
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换;
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。
转换的逻辑按优先级依次为:

:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。
1. 如果 ``mapping`` 是一个函数,那么会直接返回 ``mapping(data)``;
2. 如果 ``mapping`` 是一个 ``Dict``,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``;
* 如果 ``data`` 是 ``Dict``,那么该函数会将 ``data`` 的 ``key`` 替换为 ``mapping[key]``;
* 如果 ``data`` 是 ``dataclass``,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 ``Dict``,然后进行转换;
* 如果 ``data`` 是 ``Sequence``,那么该函数会先将其转换成一个对应的字典::
{
"_0": list[0],
"_1": list[1],
...
}

然后使用 ``mapping`` 对这个 ``Dict`` 进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``\'\_number\'`` 这个形式。

:param mapping: 用于转换的字典或者函数;``mapping`` 是函数时,返回值必须为字典类型。
:param data: 需要被转换的对象;
:return: 返回转换好的结果;
"""
@@ -320,21 +330,20 @@ def apply_to_collection(
include_none: bool = True,
**kwargs: Any,
) -> Any:
"""将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。

this function credit to: https://github.com/PyTorchLightning/pytorch-lightning
Args:
data: the collection to apply the function to
dtype: the given function will be applied to all elements of this dtype
function: the function to apply
*args: positional arguments (will be forwarded to calls of ``function``)
wrong_dtype: the given function won't be applied if this type is specified and the given collections
is of the ``wrong_dtype`` even if it is of type ``dtype``
include_none: Whether to include an element if the output of ``function`` is ``None``.
**kwargs: keyword arguments (will be forwarded to calls of ``function``)

Returns:
The resulting collection
"""
使用函数 ``function`` 递归地在 ``data`` 中的元素执行,但是仅在满足元素为 ``dtype`` 时执行。

该函数参考了 `pytorch-lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ 的实现

:param data: 需要进行处理的数据集合或数据
:param dtype: 数据的类型,函数 ``function`` 只会被应用于 ``data`` 中类型为 ``dtype`` 的数据
:param function: 对数据进行处理的函数
:param args: ``function`` 所需要的其它参数
:param wrong_dtype: ``function`` 一定不会生效的数据类型。如果数据既是 ``wrong_dtype`` 类型又是 ``dtype`` 类型
那么也不会生效。
:param include_none: 是否包含执行结果为 ``None`` 的数据,默认为 ``True``。
:param kwargs: ``function`` 所需要的其它参数
:return: 经过 ``function`` 处理后的数据集合
"""
# Breaking condition
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
@@ -402,16 +411,18 @@ def apply_to_collection(
@contextmanager
def nullcontext():
r"""
用来实现一个什么 dummy 的 context 上下文环境;
实现一个什么都不做的上下文环境
"""
yield


def sub_column(string: str, c: int, c_size: int, title: str) -> str:
r"""
对传入的字符串进行截断,方便在命令行中显示

:param string: 要被截断的字符串
:param c: 命令行列数
:param c_size: instance或dataset field数
:param c_size: :class:`~fastNLP.core.Instance` 或 :class:`fastNLP.core.DataSet` 的 ``field`` 数目
:param title: 列名
:return: 对一个过长的列进行截断的结果
"""
@@ -442,18 +453,17 @@ def _is_iterable(value):

def pretty_table_printer(dataset_or_ins) -> PrettyTable:
r"""
:param dataset_or_ins: 传入一个dataSet或者instance

.. code-block::
在 ``fastNLP`` 中展示数据的函数::

ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
>>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
+-----------+-----------+-----------------+
| field_1 | field_2 | field_3 |
+-----------+-----------+-----------------+
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
+-----------+-----------+-----------------+

:return: 以 pretty table的形式返回根据terminal大小进行自动截断
:param dataset_or_ins: 要展示的 :class:`~fastNLP.core.DataSet` 或者 :class:`~fastNLP.core.Instance`
:return: 根据 ``terminal`` 大小进行自动截断的数据表格
"""
x = PrettyTable()
try:
@@ -486,7 +496,7 @@ def pretty_table_printer(dataset_or_ins) -> PrettyTable:


class Option(dict):
r"""a dict can treat keys as attributes"""
r"""将键转化为属性的字典类型"""

def __getattr__(self, item):
try:
@@ -516,11 +526,10 @@ _emitted_deprecation_warnings = set()


def deprecated(help_message: Optional[str] = None):
"""Decorator to mark a function as deprecated.
"""
标记当前功能已经过时的装饰器。

Args:
help_message (`Optional[str]`): An optional message to guide the user on how to
switch to non-deprecated usage of the library.
:param help_message: 一段指引信息,告知用户如何将代码切换为当前版本提倡的用法。
"""

def decorator(deprecated_function: Callable):
@@ -549,11 +558,10 @@ def deprecated(help_message: Optional[str] = None):
return decorator


def seq_len_to_mask(seq_len, max_len=None):
def seq_len_to_mask(seq_len, max_len: Optional[int]):
r"""

将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。
转变 1-d seq_len到2-d mask.
将一个表示 ``sequence length`` 的一维数组转换为二维的 ``mask`` ,不包含的位置为 **0**。

.. code-block::

@@ -570,10 +578,11 @@ def seq_len_to_mask(seq_len, max_len=None):
>>>print(mask.size())
torch.Size([14, 100])

:param np.ndarray,torch.LongTensor seq_len: shape将是(B,)
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。
:return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8
:param seq_len: 大小为是 ``(B,)`` 的长度序列
:param int max_len: 将长度 ``pad`` 到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度。
但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入
一个 ``max_len`` 使得 ``mask`` 的长度 ``pad`` 到该长度。
:return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8``
"""
if isinstance(seq_len, np.ndarray):
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."


+ 5
- 3
fastNLP/envs/utils.py View File

@@ -6,6 +6,7 @@ from packaging.version import Version
import subprocess
import pkg_resources

__all__ = []

def _module_available(module_path: str) -> bool:
"""Check if a path is available in your environment.
@@ -48,10 +49,11 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
pkg_version = Version(pkg_version.base_version)
return op(pkg_version, Version(version))

def get_gpu_count():
def get_gpu_count() -> int:
"""
利用命令行获取gpu数目的函数
:return: gpu数目,如果没有显卡设备则为-1
利用命令行获取 ``gpu`` 数目的函数

:return: 显卡数目,如果没有显卡设备则为-1
"""
try:
lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv'])


Loading…
Cancel
Save