@@ -16,7 +16,7 @@ help:
 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS)

 apidoc:
-	$(SPHINXAPIDOC) -efM -d 6 -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE)
+	$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE)

 server:
 	cd build/html && python -m http.server $(PORT)
@@ -24,6 +24,9 @@ server:
 delete:
 	rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build

+web:
+	make html && make server
+
 dev:
 	make delete && make apidoc && make html && make server
@@ -42,7 +42,8 @@ extensions = [
     'sphinx.ext.viewcode',
     'sphinx.ext.autosummary',
     'sphinx.ext.mathjax',
-    'sphinx.ext.todo'
+    'sphinx.ext.todo',
+    'sphinx_autodoc_typehints'
 ]

 autodoc_default_options = {
@@ -53,8 +54,10 @@ autodoc_default_options = {
 add_module_names = False
 autosummary_ignore_module_all = False
-autodoc_typehints = "description"
+# autodoc_typehints = "description"
 autoclass_content = "class"
+typehints_fully_qualified = False
+typehints_defaults = "comma"

 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
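For reference: with ``sphinx_autodoc_typehints`` active, type hints are moved out of the rendered signature and into the parameter descriptions, and ``typehints_defaults = "comma"`` appends each default after the type. A sketch with an illustrative function (not from fastNLP)::

    def scale(x: float, factor: float = 2.0) -> float:
        """Multiply ``x`` by ``factor``.

        :param x: the value to scale
        :param factor: the scaling factor
        """
        return x * factor

    # Rendered roughly as:
    #   scale(x, factor=2.0)
    #     x (float) -- the value to scale
    #     factor (float, default: 2.0) -- the scaling factor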
@@ -168,8 +171,8 @@ texinfo_documents = [
 # -- Extension configuration -------------------------------------------------

 def maybe_skip_member(app, what, name, obj, skip, options):
-    # if obj.__doc__ is None:
-    #     return True
+    if obj.__doc__ is None:
+        return True
     if name == "__init__":
         return False
     if name.startswith("_"):
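The handler above only takes effect once it is connected to autodoc's ``autodoc-skip-member`` event; the registration lies outside this hunk, but in a typical ``conf.py`` it looks like this (a sketch assuming the standard Sphinx ``setup`` hook)::

    def setup(app):
        # Let autodoc consult maybe_skip_member for every candidate member.
        app.connect("autodoc-skip-member", maybe_skip_member)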
@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.callbacks.torch_callbacks

@@ -18,7 +18,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.callbacks.callback
    fastNLP.core.callbacks.callback_event

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback
    fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.collators.padders.exceptions
    fastNLP.core.collators.padders.get_padder

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.collators.padders

@@ -18,7 +18,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.collators.collator
    fastNLP.core.collators.packer_unpacker

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.controllers.loops.evaluate_batch_loop
    fastNLP.core.controllers.loops.loop

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.controllers.loops
    fastNLP.core.controllers.utils

@@ -19,7 +19,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.controllers.evaluator
    fastNLP.core.controllers.trainer

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.controllers.utils.state
    fastNLP.core.controllers.utils.utils

@@ -10,6 +10,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.dataloaders.jittor_dataloader.fdl

@@ -10,6 +10,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.dataloaders.paddle_dataloader.fdl

@@ -0,0 +1,7 @@
+fastNLP.core.dataloaders.prepare\_dataloader module
+===================================================
+
+.. automodule:: fastNLP.core.dataloaders.prepare_dataloader
+   :members:
+   :undoc-members:
+   :show-inheritance:

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.dataloaders.jittor_dataloader
    fastNLP.core.dataloaders.paddle_dataloader

@@ -20,7 +20,8 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.dataloaders.mix_dataloader
+   fastNLP.core.dataloaders.prepare_dataloader
    fastNLP.core.dataloaders.utils

@@ -10,6 +10,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.dataloaders.torch_dataloader.fdl

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.dataset.dataset
    fastNLP.core.dataset.field

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.drivers.jittor_driver.initialize_jittor_driver
    fastNLP.core.drivers.jittor_driver.jittor_driver

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.drivers.paddle_driver.dist_utils
    fastNLP.core.drivers.paddle_driver.fleet

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.drivers.jittor_driver
    fastNLP.core.drivers.paddle_driver

@@ -20,7 +20,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.drivers.choose_driver
    fastNLP.core.drivers.driver

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.drivers.torch_driver.ddp
    fastNLP.core.drivers.torch_driver.dist_utils

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.log.handler
    fastNLP.core.log.highlighter

@@ -10,6 +10,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.backend.jittor_backend.backend

@@ -10,6 +10,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.backend.paddle_backend.backend

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.backend.jittor_backend
    fastNLP.core.metrics.backend.paddle_backend

@@ -20,7 +20,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.backend.auto_backend
    fastNLP.core.metrics.backend.backend

@@ -10,6 +10,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.backend.torch_backend.backend

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.backend

@@ -18,7 +18,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.metrics.accuracy
    fastNLP.core.metrics.classify_f1_pre_rec_metric

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.callbacks
    fastNLP.core.collators

@@ -27,6 +27,6 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.vocabulary

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.samplers.conversion_utils
    fastNLP.core.samplers.mix_sampler

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core.utils.cache_results
    fastNLP.core.utils.dummy_class

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.envs.distributed
    fastNLP.envs.env

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.io.loader.classification
    fastNLP.io.loader.conll

@@ -10,7 +10,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.io.pipe.classification
    fastNLP.io.pipe.conll

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.io.loader
    fastNLP.io.pipe

@@ -19,7 +19,7 @@ Submodules
 ----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.io.data_bundle
    fastNLP.io.embed_loader

@@ -10,7 +10,7 @@ Subpackages
 -----------

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP.core
    fastNLP.envs

@@ -2,6 +2,6 @@ fastNLP
 =======

 .. toctree::
-   :maxdepth: 6
+   :maxdepth: 4

    fastNLP
@@ -8,8 +8,8 @@ from fastNLP.core.utils.utils import _get_fun_msg
 def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->Tuple[str, float]:
     """
-    Look up monitor in res and return it. If monitor is not found, fall back to _real_monitor; if _real_monitor is
-    None, try matching with the value of monitor.
+    Look up ``monitor`` in ``res`` and return it. If ``monitor`` is not found, fall back to ``_real_monitor``;
+    if ``_real_monitor`` is ``None``, try matching with the value of ``monitor``.

     :param monitor:
     :param real_monitor:
@@ -162,9 +162,9 @@ class PaddleDataLoader(DataLoader):
     def get_batch_indices(self) -> List[int]:
         """
-        Get the idx of the current batch
+        Get the index of each sample in the current ``batch``.

-        :return:
+        :return: the indices of the data in the current ``batch``
         """
         return self.cur_batch_indices
@@ -170,9 +170,9 @@ class TorchDataLoader(DataLoader):
     def get_batch_indices(self) -> List[int]:
         """
-        Get the idx of the current batch
+        Get the index of each sample in the current ``batch``.

-        :return:
+        :return: the indices of the data in the current ``batch``
         """
         return self.cur_batch_indices
@@ -400,15 +400,16 @@ class DataSet:
                     new_field_name: str = None, num_proc: int = 0,
                     progress_desc: str = None, show_progress_bar: bool = True):
         r"""
-        Pass the field named `field_name` of each instance in the DataSet to func, and collect its return values.
-
-        :param field_name: which field is passed to func.
-        :param func: its input is the content of the field named `field_name` in each instance.
-        :param new_field_name: put the content returned by func into the field `new_field_name`; if the name is the
-            same as an existing field, the previous field is overwritten. If None, no new field is created.
-        :param num_proc: the number of processes. Note that, due to the nature of python, using more processes
-            leads to a proportional growth in memory.
-        :param progress_desc: the value of progress_desc, defaults to Main
-        :param show_progress_bar: whether to show the progress bar, shown by default
+        Pass the ``field`` named ``field_name`` of each ``instance`` in the :class:`~DataSet` to the function
+        ``func``, and collect the function's return values.
+
+        :param field_name: the name of the ``field`` passed to ``func``.
+        :param func: a function whose input is the content of the ``field`` named ``field_name`` in each
+            ``instance``.
+        :param new_field_name: put the content returned by ``func`` into the ``field`` named ``new_field_name``;
+            if the name is the same as an existing ``field``, that ``field`` is overwritten. If ``None``, no
+            ``field`` is overwritten or created.
+        :param num_proc: the number of processes to use. Note that, due to the nature of ``python``, using more
+            processes leads to a proportional growth in memory.
+        :param progress_desc: the description string of the progress bar, defaults to ``Main``.
+        :param show_progress_bar: whether to display the progress bar; displayed by default.
+        :return: the return values obtained from the function ``func``.
         """
         assert len(self) != 0, "Null DataSet cannot use apply_field()."
         if not self.has_field(field_name=field_name):
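A minimal usage sketch of ``apply_field`` as documented above, with hypothetical data::

    from fastNLP import DataSet

    ds = DataSet({"words": [["a", "b"], ["c", "d", "e"]]})
    # Store the length of each "words" field in a new "seq_len" field.
    ds.apply_field(len, field_name="words", new_field_name="seq_len")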
@@ -1,4 +1,4 @@
-import functools
+__all__ = []

 class DummyClass:
     def __init__(self, *args, **kwargs):
@@ -1,7 +1,6 @@
 """
-This file provides fastNLP with a unified progress bar manager; only by sharing one Task object can the progress bar
-in the trainer and the progress bar in evaluation avoid conflicts
+This file provides ``fastNLP`` with a unified ``progress bar`` manager; only by sharing one ``Task`` object can the ``progress bar``
+of :class:`~fastNLP.core.Trainer` and the ``progress bar`` of :class:`~fastNLP.core.Evaluator` avoid conflicts
 """
 import sys
 from typing import Any, Union, Optional
@@ -10,10 +10,6 @@ from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence
 from typing import Tuple, Optional
 from time import sleep

-try:
-    from typing import Literal, Final
-except ImportError:
-    from typing_extensions import Literal, Final
 import os
 from contextlib import contextmanager
 from functools import wraps
@@ -22,7 +18,6 @@ import numpy as np
 from pathlib import Path

 from fastNLP.core.log import logger
-from ...envs import SUPPORT_BACKENDS

 __all__ = [
@@ -43,10 +38,10 @@ __all__ = [
 def get_fn_arg_names(fn: Callable) -> List[str]:
     r"""
-    Return the names of all parameters of a function;
+    Return the names of all parameters of a function

-    :param fn: the function to inspect;
-    :return: a list whose elements are the string names of the inspected function's parameters;
+    :param fn: the function to inspect
+    :return: a list whose elements are the string names of the parameters of ``fn``
     """
     return list(inspect.signature(fn).parameters)
@@ -54,24 +49,18 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
 def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                     mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
     r"""
-    This function matches values from *args (which therefore all need to be of type dict) against the parameter
-    names of the input function and calls it; if the passed-in data does not match the parameters of fn, it can be
-    converted through the mapping parameter. A pair (key, value) in mapping means: find the value under this key in
-    *args and pass it to the parameter named value.
+    This function matches values from ``*args`` (which therefore all need to be of type ``dict``) against the
+    parameter names of the input function and calls it; if the passed-in data does not match the parameters of
+    ``fn``, it can be converted through the ``mapping`` parameter. A pair ``(key, value)`` in ``mapping`` means:
+    find the value under ``key`` in ``*args`` and pass it to the parameter named ``value``.

-    1.This function lets users drive the call automatically through string matching;
-    2.Note that mapping defaults to None; if you want to specify how the inputs correspond to the parameters of the
-    called function, you should pass mapping in as such a dict; if mapping is not None, we will always first use
-    mapping to rewrite the keys of the input dicts, so please check the correctness of mapping yourself;
-    3.If a parameter of the input function has a default value and none of the later inputs provides a value for
-    it, the default value is used; otherwise the later input's value is used;
-    4.If the input function is a `partial` function, the situation is the same as in '3.', i.e. the same as with
-    default parameters;
-
-    :param fn: the function performing the actual computation; its parameters may have default values;
-    :param args: a series of positional arguments, which should be dicts; the actual parameters needed for the
-        computation of `fn` are extracted from these inputs;
-    :param signature_fn: a function used to replace the signature of `fn`; if this parameter is not None, we first
-        extract the signature from this function, extract the parameter values through that signature, and then
-        pass them to `fn` for the actual computation;
-    :param mapping: a dict used to rewrite the keys of the dicts that precede it;
-    :return: the result of running `fn`;
+    1. This function lets users drive the call automatically through string matching;
+    2. Note that ``mapping`` defaults to ``None``; if you want to specify how the inputs correspond to the
+       parameters of the called function, you should pass ``mapping`` in as a dict; if ``mapping`` is not ``None``,
+       we will always first use ``mapping`` to rewrite the keys of the input dicts, so please check the correctness
+       of ``mapping`` yourself;
+    3. If a parameter of the input function has a default value and none of the later inputs provides a value for
+       it, the default value is used; otherwise the later input's value is used;
+    4. If the input function is a ``partial`` function, the situation is the same as in the third point, i.e. the
+       same as with default parameters;

     Examples::

         >>> # 1
         >>> loss_fn = CrossEntropyLoss()  # if the parameters it needs are def CrossEntropyLoss(y, pred);
         >>> batch = {"x": 20, "y": 1}
@@ -84,6 +73,14 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
         >>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30}))  # res: 70
         >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20}))  # res: 140
         >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200}))  # res: 240

+    :param fn: the function performing the actual computation; its parameters may have default values;
+    :param args: a series of positional arguments, which should be dicts; the actual parameters needed for the
+        computation of ``fn`` are extracted from these inputs;
+    :param signature_fn: a function used to replace the signature of ``fn``; if this parameter is not ``None``, we
+        first extract the signature from this function, extract the parameter values through that signature, and
+        then pass them to ``fn`` for the actual computation;
+    :param mapping: a dict used to rewrite the keys of the dicts that precede it;
+    :return: the result of running ``fn``;
     """

     if signature_fn is not None:
@@ -226,13 +223,13 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
 def check_user_specific_params(user_params: Dict, fn: Callable):
     """
-    This function assigns values to the parameters of the given function based on the user's input;
-    it is mainly used in cases where the user cannot call the function directly;
-    its main purpose is to help check whether the user's parameter inputs for the function fn are wrong;
+    This function assigns values to the parameters of the given function based on the user's input; it is mainly
+    used in cases where the user cannot call the function directly;
+    its main purpose is to help check whether the user's parameter inputs for the function ``fn`` are wrong;

-    :param user_params: the parameter values specified by the user; should be a dict whose keys are the parameter
-        names and whose values are the intended values of the parameters;
-    :param fn: the function that will be called;
-    :return: a dict of the parameter values that will actually be passed in when fn is called later;
+    :param user_params: the parameter values specified by the user; should be a dict whose ``key`` is each
+        parameter's name and whose ``value`` is each parameter's value;
+    :param fn: the function that is going to be called;
+    :return: a dict of the parameter values that will actually be passed in when ``fn`` is called later;
     """
     fn_arg_names = get_fn_arg_names(fn)
@@ -243,6 +240,9 @@ def check_user_specific_params(user_params: Dict, fn: Callable):

 def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
+    """
+    Convert the given ``dataclass`` instance into a dict.
+    """
     if not is_dataclass(data):
         raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.")
     _dict = dict()
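An illustrative call of ``dataclass_to_dict``, assuming a hypothetical dataclass::

    from dataclasses import dataclass

    @dataclass
    class TrainingArgs:
        lr: float = 1e-3
        n_epochs: int = 10

    dataclass_to_dict(TrainingArgs())  # -> {'lr': 0.001, 'n_epochs': 10}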
@@ -253,21 +253,31 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
 def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
     r"""
-    Implements replacing the keys of the input (batch) or the output (outputs) through `mapping`;
-    this function is applied to `input_mapping` and `output_mapping`;
-    for `input_mapping`, it is called in `TrainBatchLoop` right after the data is fetched;
-    for `output_mapping`, it is called right after the results are obtained in `Trainer.train_step` and `Evaluator.train_step`;
-
-    The conversion logic, in order of priority, is:
-
-    1. If `mapping` is a function, `mapping(data)` is returned directly;
-    2. If `mapping` is a `Dict`, the type of `data` can only be one of the following three: [`Dict`, `dataclass`, `Sequence`];
-    if `data` is a `Dict`, this function replaces each key of `data` with mapping[key];
-    if `data` is a `dataclass`, this function first converts it into a `Dict` with the `dataclasses.asdict` function, then converts it;
-    if `data` is a `Sequence`, this function first converts it into a corresponding `Dict`: {"_0": list[0], "_1": list[1], ...}, then uses
-    mapping to convert this `Dict`; keys that do not match any key in mapping keep the "_number" form.
-
-    :param mapping: the dict or function used for the conversion; when mapping is a function, its return value must be of type dict.
+    Implements replacing the keys of the input ``batch`` or the output ``outputs`` through ``mapping``;
+    this function is applied to ``input_mapping`` and ``output_mapping``;
+
+    For ``input_mapping``, it is called in :class:`~fastNLP.core.controllers.TrainBatchLoop` right after the data is fetched;
+    for ``output_mapping``, it is called right after the results are obtained in :meth:`~fastNLP.core.Trainer.train_step` of
+    :class:`~fastNLP.core.Trainer` and :meth:`~fastNLP.core.Evaluator.train_step` of :class:`~fastNLP.core.Evaluator`;
+
+    The conversion logic, in order of priority, is:
+
+    1. If ``mapping`` is a function, ``mapping(data)`` is returned directly;
+    2. If ``mapping`` is a ``Dict``, the type of ``data`` can only be one of the following three: ``[Dict, dataclass, Sequence]``;
+
+       * if ``data`` is a ``Dict``, this function replaces each ``key`` of ``data`` with ``mapping[key]``;
+       * if ``data`` is a ``dataclass``, this function first converts it into a ``Dict`` with the :func:`dataclasses.asdict` function, then converts it;
+       * if ``data`` is a ``Sequence``, this function first converts it into a corresponding dict::
+
+            {
+                "_0": list[0],
+                "_1": list[1],
+                ...
+            }
+
+         then uses ``mapping`` to convert this ``Dict``; keys that do not match any ``key`` in ``mapping`` keep the ``'_number'`` form.
+
+    :param mapping: the dict or function used for the conversion; when ``mapping`` is a function, its return value must be of type dict.
     :param data: the object to be converted;
     :return: the converted result;
     """
@@ -320,21 +330,20 @@ def apply_to_collection(
     include_none: bool = True,
     **kwargs: Any,
 ) -> Any:
-    """Applies the function function recursively to the elements in data, but only to elements of type dtype.
-    this function credit to: https://github.com/PyTorchLightning/pytorch-lightning
-    Args:
-        data: the collection to apply the function to
-        dtype: the given function will be applied to all elements of this dtype
-        function: the function to apply
-        *args: positional arguments (will be forwarded to calls of ``function``)
-        wrong_dtype: the given function won't be applied if this type is specified and the given collections
-            is of the ``wrong_dtype`` even if it is of type ``dtype``
-        include_none: Whether to include an element if the output of ``function`` is ``None``.
-        **kwargs: keyword arguments (will be forwarded to calls of ``function``)
-    Returns:
-        The resulting collection
+    """
+    Applies the function ``function`` recursively to the elements in ``data``, but only to elements of type
+    ``dtype``.
+
+    This function is adapted from the implementation in
+    `pytorch-lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_.
+
+    :param data: the collection or data to process
+    :param dtype: the type of data; ``function`` is only applied to elements of ``data`` that are of type ``dtype``
+    :param function: the function that processes the data
+    :param args: other arguments needed by ``function``
+    :param wrong_dtype: the type for which ``function`` must not take effect. If the data is of type
+        ``wrong_dtype`` as well as of type ``dtype``, the function is not applied either.
+    :param include_none: whether to include elements for which the result of ``function`` is ``None``, defaults
+        to ``True``.
+    :param kwargs: other keyword arguments needed by ``function``
+    :return: the collection after being processed by ``function``
     """
     # Breaking condition
     if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
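A usage sketch matching the parameters documented above, with hypothetical data::

    batch = {"x": [1, 2], "y": {"z": 3}}
    doubled = apply_to_collection(batch, dtype=int, function=lambda v: v * 2)
    # doubled == {"x": [2, 4], "y": {"z": 6}}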
@@ -402,16 +411,18 @@ def apply_to_collection(
 @contextmanager
 def nullcontext():
     r"""
-    Implements a dummy context environment that does nothing;
+    Implements a context environment that does nothing
     """
     yield


 def sub_column(string: str, c: int, c_size: int, title: str) -> str:
     r"""
+    Truncate the given string so that it displays well on the command line
+
     :param string: the string to be truncated
     :param c: the number of columns of the command line
-    :param c_size: the number of fields of the instance or dataset
+    :param c_size: the number of ``field`` s of the :class:`~fastNLP.core.Instance` or :class:`fastNLP.core.DataSet`
     :param title: the column name
     :return: the result of truncating an overlong column
     """
@@ -442,18 +453,17 @@ def _is_iterable(value):
 def pretty_table_printer(dataset_or_ins) -> PrettyTable:
     r"""
-    :param dataset_or_ins: pass in a dataSet or an instance
-
-    .. code-block::
+    Function for displaying data in ``fastNLP``::

-        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
+        >>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
         +-----------+-----------+-----------------+
         |  field_1  |  field_2  |     field_3     |
         +-----------+-----------+-----------------+
         | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
         +-----------+-----------+-----------------+

-    :return: returns, in the form of a pretty table, the result of automatic truncation according to the terminal size
+    :param dataset_or_ins: the :class:`~fastNLP.core.DataSet` or :class:`~fastNLP.core.Instance` to display
+    :return: a data table automatically truncated according to the ``terminal`` size
     """
     x = PrettyTable()
     try:
@@ -486,7 +496,7 @@ def pretty_table_printer(dataset_or_ins) -> PrettyTable:
 class Option(dict):
-    r"""a dict can treat keys as attributes"""
+    r"""A dict type that turns keys into attributes"""

     def __getattr__(self, item):
         try:
@@ -516,11 +526,10 @@ _emitted_deprecation_warnings = set()

 def deprecated(help_message: Optional[str] = None):
-    """Decorator to mark a function as deprecated.
-
-    Args:
-        help_message (`Optional[str]`): An optional message to guide the user on how to
-            switch to non-deprecated usage of the library.
+    """
+    Decorator that marks the current functionality as deprecated.
+
+    :param help_message: a message guiding users on how to switch their code to the usage recommended by the
+        current version.
     """

     def decorator(deprecated_function: Callable):
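A usage sketch of the decorator with a hypothetical function and message::

    @deprecated("`old_transform` is deprecated; use `DataSet.apply_field` instead.")
    def old_transform(ds):
        ...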
@@ -549,11 +558,10 @@ def deprecated(help_message: Optional[str] = None):
     return decorator


-def seq_len_to_mask(seq_len, max_len=None):
+def seq_len_to_mask(seq_len, max_len: Optional[int] = None):
     r"""
-    Converts a one-dimensional array representing sequence lengths into a two-dimensional mask; positions beyond a
-    sequence's length are 0. Turns a 1-d seq_len into a 2-d mask.
+    Converts a one-dimensional array representing ``sequence length`` into a two-dimensional ``mask``; positions
+    beyond a sequence's length are **0**.

     .. code-block::
@@ -570,10 +578,11 @@ def seq_len_to_mask(seq_len, max_len=None):
         >>> print(mask.size())
         torch.Size([14, 100])

-    :param np.ndarray,torch.LongTensor seq_len: the shape will be (B,)
-    :param int max_len: pad the lengths to this length. By default (None), the longest length in seq_len is used.
-        But in nn.DataParallel scenarios the seq_len on different cards may differ, so a max_len needs to be
-        passed in so that the mask is padded to that length.
-    :return: np.ndarray or torch.Tensor; the shape will be (B, max_length) and the element type is bool or torch.uint8
+    :param seq_len: a sequence of lengths of size ``(B,)``
+    :param int max_len: ``pad`` the lengths to ``max_len``. By default (``None``), the longest length in
+        ``seq_len`` is used. But in distributed settings such as :class:`torch.nn.DataParallel`, the ``seq_len``
+        on different cards may differ, so a ``max_len`` needs to be passed in so that the length of the ``mask``
+        is padded to that length.
+    :return: a ``mask`` of size ``(B, max_len)``; the element type is ``bool`` or ``uint8``
     """
     if isinstance(seq_len, np.ndarray):
         assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
@@ -6,6 +6,7 @@ from packaging.version import Version
 import subprocess
 import pkg_resources

+__all__ = []

 def _module_available(module_path: str) -> bool:
     """Check if a path is available in your environment.
@@ -48,10 +49,11 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
         pkg_version = Version(pkg_version.base_version)
     return op(pkg_version, Version(version))

-def get_gpu_count():
+def get_gpu_count() -> int:
     """
-    A function that gets the number of gpus via the command line
-    :return: the number of gpus, -1 if there is no GPU device
+    A function that gets the number of ``gpu`` devices via the command line

+    :return: the number of GPUs, ``-1`` if there is no GPU device
     """
     try:
         lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv'])