diff --git a/docs/Makefile b/docs/Makefile index 9f807ae2..d6c4f6b6 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -16,7 +16,7 @@ help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) apidoc: - $(SPHINXAPIDOC) -efM -d 6 -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE) + $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) $(SPHINXEXCLUDE) server: cd build/html && python -m http.server $(PORT) @@ -24,6 +24,9 @@ server: delete: rm -f source/$(SPHINXPROJ).* source/modules.rst && rm -rf build +web: + make html && make server + dev: make delete && make apidoc && make html && make server diff --git a/docs/source/conf.py b/docs/source/conf.py index 115448ed..2ed8ac96 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -42,7 +42,8 @@ extensions = [ 'sphinx.ext.viewcode', 'sphinx.ext.autosummary', 'sphinx.ext.mathjax', - 'sphinx.ext.todo' + 'sphinx.ext.todo', + 'sphinx_autodoc_typehints' ] autodoc_default_options = { @@ -53,8 +54,10 @@ autodoc_default_options = { add_module_names = False autosummary_ignore_module_all = False -autodoc_typehints = "description" +# autodoc_typehints = "description" autoclass_content = "class" +typehints_fully_qualified = False +typehints_defaults = "comma" # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -168,8 +171,8 @@ texinfo_documents = [ # -- Extension configuration ------------------------------------------------- def maybe_skip_member(app, what, name, obj, skip, options): - # if obj.__doc__ is None: - # return True + if obj.__doc__ is None: + return True if name == "__init__": return False if name.startswith("_"): diff --git a/docs/source/fastNLP.core.callbacks.rst b/docs/source/fastNLP.core.callbacks.rst index a3450110..0f3f93ac 100644 --- a/docs/source/fastNLP.core.callbacks.rst +++ b/docs/source/fastNLP.core.callbacks.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.callbacks.torch_callbacks @@ -18,7 +18,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.callbacks.callback fastNLP.core.callbacks.callback_event diff --git a/docs/source/fastNLP.core.callbacks.torch_callbacks.rst b/docs/source/fastNLP.core.callbacks.torch_callbacks.rst index 193f46d3..6f00f6f7 100644 --- a/docs/source/fastNLP.core.callbacks.torch_callbacks.rst +++ b/docs/source/fastNLP.core.callbacks.torch_callbacks.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.callbacks.torch_callbacks.torch_grad_clip_callback fastNLP.core.callbacks.torch_callbacks.torch_lr_sched_callback diff --git a/docs/source/fastNLP.core.collators.padders.rst b/docs/source/fastNLP.core.collators.padders.rst index 0ee61a26..6f40becb 100644 --- a/docs/source/fastNLP.core.collators.padders.rst +++ b/docs/source/fastNLP.core.collators.padders.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.collators.padders.exceptions fastNLP.core.collators.padders.get_padder diff --git a/docs/source/fastNLP.core.collators.rst b/docs/source/fastNLP.core.collators.rst index 1210e8b3..22259c12 100644 --- a/docs/source/fastNLP.core.collators.rst +++ b/docs/source/fastNLP.core.collators.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.collators.padders @@ -18,7 +18,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.collators.collator fastNLP.core.collators.packer_unpacker diff --git a/docs/source/fastNLP.core.controllers.loops.rst b/docs/source/fastNLP.core.controllers.loops.rst index 8db384f7..39879148 100644 --- a/docs/source/fastNLP.core.controllers.loops.rst +++ b/docs/source/fastNLP.core.controllers.loops.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.controllers.loops.evaluate_batch_loop fastNLP.core.controllers.loops.loop diff --git a/docs/source/fastNLP.core.controllers.rst b/docs/source/fastNLP.core.controllers.rst index daef5f3b..9440fbe4 100644 --- a/docs/source/fastNLP.core.controllers.rst +++ b/docs/source/fastNLP.core.controllers.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.controllers.loops fastNLP.core.controllers.utils @@ -19,7 +19,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.controllers.evaluator fastNLP.core.controllers.trainer diff --git a/docs/source/fastNLP.core.controllers.utils.rst b/docs/source/fastNLP.core.controllers.utils.rst index ca8a7307..f7bcc38c 100644 --- a/docs/source/fastNLP.core.controllers.utils.rst +++ b/docs/source/fastNLP.core.controllers.utils.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.controllers.utils.state fastNLP.core.controllers.utils.utils diff --git a/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst b/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst index d7a7a8dc..78d90c46 100644 --- a/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst +++ b/docs/source/fastNLP.core.dataloaders.jittor_dataloader.rst @@ -10,6 +10,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.dataloaders.jittor_dataloader.fdl diff --git a/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst b/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst index 428a339e..dc4481d2 100644 --- a/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst +++ b/docs/source/fastNLP.core.dataloaders.paddle_dataloader.rst @@ -10,6 +10,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.dataloaders.paddle_dataloader.fdl diff --git a/docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst b/docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst new file mode 100644 index 00000000..ac8c8c20 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.prepare_dataloader.rst @@ -0,0 +1,7 @@ +fastNLP.core.dataloaders.prepare\_dataloader module +=================================================== + +.. automodule:: fastNLP.core.dataloaders.prepare_dataloader + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.rst b/docs/source/fastNLP.core.dataloaders.rst index a9bd51fa..e8c6b799 100644 --- a/docs/source/fastNLP.core.dataloaders.rst +++ b/docs/source/fastNLP.core.dataloaders.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.dataloaders.jittor_dataloader fastNLP.core.dataloaders.paddle_dataloader @@ -20,7 +20,8 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.dataloaders.mix_dataloader + fastNLP.core.dataloaders.prepare_dataloader fastNLP.core.dataloaders.utils diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst index f631571e..c9acca23 100644 --- a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst +++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst @@ -10,6 +10,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.dataloaders.torch_dataloader.fdl diff --git a/docs/source/fastNLP.core.dataset.rst b/docs/source/fastNLP.core.dataset.rst index e3ceff77..dc36250a 100644 --- a/docs/source/fastNLP.core.dataset.rst +++ b/docs/source/fastNLP.core.dataset.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.dataset.dataset fastNLP.core.dataset.field diff --git a/docs/source/fastNLP.core.drivers.jittor_driver.rst b/docs/source/fastNLP.core.drivers.jittor_driver.rst index df32665b..7ec101c7 100644 --- a/docs/source/fastNLP.core.drivers.jittor_driver.rst +++ b/docs/source/fastNLP.core.drivers.jittor_driver.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.drivers.jittor_driver.initialize_jittor_driver fastNLP.core.drivers.jittor_driver.jittor_driver diff --git a/docs/source/fastNLP.core.drivers.paddle_driver.rst b/docs/source/fastNLP.core.drivers.paddle_driver.rst index 91038646..0f115eb5 100644 --- a/docs/source/fastNLP.core.drivers.paddle_driver.rst +++ b/docs/source/fastNLP.core.drivers.paddle_driver.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.drivers.paddle_driver.dist_utils fastNLP.core.drivers.paddle_driver.fleet diff --git a/docs/source/fastNLP.core.drivers.rst b/docs/source/fastNLP.core.drivers.rst index 3ac36f71..bb168c76 100644 --- a/docs/source/fastNLP.core.drivers.rst +++ b/docs/source/fastNLP.core.drivers.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.drivers.jittor_driver fastNLP.core.drivers.paddle_driver @@ -20,7 +20,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.drivers.choose_driver fastNLP.core.drivers.driver diff --git a/docs/source/fastNLP.core.drivers.torch_driver.rst b/docs/source/fastNLP.core.drivers.torch_driver.rst index 65f4ca1a..9a0109a2 100644 --- a/docs/source/fastNLP.core.drivers.torch_driver.rst +++ b/docs/source/fastNLP.core.drivers.torch_driver.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.drivers.torch_driver.ddp fastNLP.core.drivers.torch_driver.dist_utils diff --git a/docs/source/fastNLP.core.log.rst b/docs/source/fastNLP.core.log.rst index e52f9eb7..6cd67753 100644 --- a/docs/source/fastNLP.core.log.rst +++ b/docs/source/fastNLP.core.log.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.log.handler fastNLP.core.log.highlighter diff --git a/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst b/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst index 9b76aee3..6ce8b0d4 100644 --- a/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst +++ b/docs/source/fastNLP.core.metrics.backend.jittor_backend.rst @@ -10,6 +10,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.backend.jittor_backend.backend diff --git a/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst b/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst index fb4ec69d..d932d4e5 100644 --- a/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst +++ b/docs/source/fastNLP.core.metrics.backend.paddle_backend.rst @@ -10,6 +10,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.backend.paddle_backend.backend diff --git a/docs/source/fastNLP.core.metrics.backend.rst b/docs/source/fastNLP.core.metrics.backend.rst index 52ca7958..5a8cf4ad 100644 --- a/docs/source/fastNLP.core.metrics.backend.rst +++ b/docs/source/fastNLP.core.metrics.backend.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.backend.jittor_backend fastNLP.core.metrics.backend.paddle_backend @@ -20,7 +20,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.backend.auto_backend fastNLP.core.metrics.backend.backend diff --git a/docs/source/fastNLP.core.metrics.backend.torch_backend.rst b/docs/source/fastNLP.core.metrics.backend.torch_backend.rst index 07beae73..f01efe88 100644 --- a/docs/source/fastNLP.core.metrics.backend.torch_backend.rst +++ b/docs/source/fastNLP.core.metrics.backend.torch_backend.rst @@ -10,6 +10,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.backend.torch_backend.backend diff --git a/docs/source/fastNLP.core.metrics.rst b/docs/source/fastNLP.core.metrics.rst index e2770769..8ad6f729 100644 --- a/docs/source/fastNLP.core.metrics.rst +++ b/docs/source/fastNLP.core.metrics.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.backend @@ -18,7 +18,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.metrics.accuracy fastNLP.core.metrics.classify_f1_pre_rec_metric diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst index d71ffaf3..57dac16a 100644 --- a/docs/source/fastNLP.core.rst +++ b/docs/source/fastNLP.core.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.callbacks fastNLP.core.collators @@ -27,6 +27,6 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.vocabulary diff --git a/docs/source/fastNLP.core.samplers.rst b/docs/source/fastNLP.core.samplers.rst index f1b7be4c..9ccd9b59 100644 --- a/docs/source/fastNLP.core.samplers.rst +++ b/docs/source/fastNLP.core.samplers.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.samplers.conversion_utils fastNLP.core.samplers.mix_sampler diff --git a/docs/source/fastNLP.core.utils.rst b/docs/source/fastNLP.core.utils.rst index a63ed1db..2d682010 100644 --- a/docs/source/fastNLP.core.utils.rst +++ b/docs/source/fastNLP.core.utils.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core.utils.cache_results fastNLP.core.utils.dummy_class diff --git a/docs/source/fastNLP.envs.rst b/docs/source/fastNLP.envs.rst index 4c95ccfe..2e642ff7 100644 --- a/docs/source/fastNLP.envs.rst +++ b/docs/source/fastNLP.envs.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.envs.distributed fastNLP.envs.env diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst index 13bd5fe9..bd91b795 100644 --- a/docs/source/fastNLP.io.loader.rst +++ b/docs/source/fastNLP.io.loader.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.io.loader.classification fastNLP.io.loader.conll diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst index d8cf306e..9ad7e539 100644 --- a/docs/source/fastNLP.io.pipe.rst +++ b/docs/source/fastNLP.io.pipe.rst @@ -10,7 +10,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.io.pipe.classification fastNLP.io.pipe.conll diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index 4fab1696..5f025bba 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.io.loader fastNLP.io.pipe @@ -19,7 +19,7 @@ Submodules ---------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.io.data_bundle fastNLP.io.embed_loader diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index bee33e72..726eb9c6 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -10,7 +10,7 @@ Subpackages ----------- .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP.core fastNLP.envs diff --git a/docs/source/modules.rst b/docs/source/modules.rst index 5515520a..e9a92cb7 100644 --- a/docs/source/modules.rst +++ b/docs/source/modules.rst @@ -2,6 +2,6 @@ fastNLP ======= .. toctree:: - :maxdepth: 6 + :maxdepth: 4 fastNLP diff --git a/fastNLP/core/callbacks/utils.py b/fastNLP/core/callbacks/utils.py index c3f8275a..865a1fc7 100644 --- a/fastNLP/core/callbacks/utils.py +++ b/fastNLP/core/callbacks/utils.py @@ -8,8 +8,8 @@ from fastNLP.core.utils.utils import _get_fun_msg def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->Tuple[str, float]: """ - 从res中寻找 monitor 并返回。如果 monitor 没找到则尝试用 _real_monitor ,若 _real_monitor 为 None 则尝试使用 monitor 的值进行 - 匹配。 + 从 ``res`` 中寻找 ``monitor`` 并返回。如果 ``monitor`` 没找到则尝试用 ``_real_monitor`` ,若 ``_real_monitor`` 为 ``None`` + 则尝试使用 ``monitor`` 的值进行匹配。 :param monitor: :param real_monitor: diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py index 5c5e3bef..342a6c19 100644 --- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py @@ -162,9 +162,9 @@ class PaddleDataLoader(DataLoader): def get_batch_indices(self) -> List[int]: """ - 获取当前 batch 的 idx + 获取当前 ``batch`` 中每条数据对应的索引。 - :return: + :return: 当前 ``batch`` 数据的索引 """ return self.cur_batch_indices diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 6a9e4af9..1f737467 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -170,9 +170,9 @@ class TorchDataLoader(DataLoader): def get_batch_indices(self) -> List[int]: """ - 获取当前 batch 的 idx + 获取当前 ``batch`` 中每条数据对应的索引。 - :return: + :return: 当前 ``batch`` 数据的索引 """ return self.cur_batch_indices diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 83e83ac9..7cba5a8c 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -400,15 +400,16 @@ class DataSet: new_field_name: str = None, num_proc: int = 0, progress_desc: str = None, show_progress_bar: bool = True): r""" - 将 DataSet 中的每个 instance 中的名为 `field_name` 的 field 传给 func,并获取它的返回值。 - - :param field_name: 传入 func 的是哪个 field。 - :param func: input是 instance 中名为 `field_name` 的 field 的内容。 - :param new_field_name: 将 func 返回的内容放入到 `new_field_name` 这个 field 中,如果名称与已有的 field 相同,则覆 - 盖之前的 field。如果为 None 则不创建新的 field。 - :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param progress_desc: progress_desc 的值,默认为 Main - :param show_progress_bar: 是否展示进度条,默认展示进度条 + 将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并获取函数的返回值。 + + :param field_name: 传入 ``func`` 的 ``field`` 名称。 + :param func: 一个函数,其输入是 ``instance`` 中名为 ``field_name`` 的 ``field`` 的内容。 + :param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 对应的 ``field`` 中,如果名称与已有的 ``field`` 相同 + 则进行覆盖。如果为 ``None`` 则不会覆盖和创建 ``field`` 。 + :param num_proc: 使用进程的数量。请注意,由于 ``python`` 语言的特性,使用了多少进程就会导致多少倍内存的增长。 + :param progress_desc: 进度条的描述字符,默认为 ``Main``。 + :param show_progress_bar: 是否展示进度条;默认为展示。 + :return: 从函数 ``func`` 中得到的返回值。 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." if not self.has_field(field_name=field_name): diff --git a/fastNLP/core/utils/dummy_class.py b/fastNLP/core/utils/dummy_class.py index 42200cbb..e7596607 100644 --- a/fastNLP/core/utils/dummy_class.py +++ b/fastNLP/core/utils/dummy_class.py @@ -1,4 +1,4 @@ -import functools +__all__ = [] class DummyClass: def __init__(self, *args, **kwargs): diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 4799765f..02a30c26 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -1,7 +1,6 @@ """ -该文件用于为fastNLP提供一个统一的progress bar管理,通过共用一个Task对象,trainer中的progress bar和evaluation中的progress bar才能 - 不冲突 - +该文件用于为 ``fastNLP`` 提供一个统一的 ``progress bar`` 管理,通过共用一个``Task`` 对象, :class:`~fastNLP.core.Trainer` 中 +的 ``progress bar`` 和 :class:`~fastNLP.core.Evaluator` 中的 ``progress bar`` 才能不冲突 """ import sys from typing import Any, Union, Optional diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index b07d8b82..ec7a8b47 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -10,10 +10,6 @@ from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence from typing import Tuple, Optional from time import sleep -try: - from typing import Literal, Final -except ImportError: - from typing_extensions import Literal, Final import os from contextlib import contextmanager from functools import wraps @@ -22,7 +18,6 @@ import numpy as np from pathlib import Path from fastNLP.core.log import logger -from ...envs import SUPPORT_BACKENDS __all__ = [ @@ -43,10 +38,10 @@ __all__ = [ def get_fn_arg_names(fn: Callable) -> List[str]: r""" - 返回一个函数的所有参数的名字; + 返回一个函数所有参数的名字 - :param fn: 需要查询的函数; - :return: 一个列表,其中的元素则是查询函数的参数的字符串名字; + :param fn: 需要查询的函数 + :return: 一个列表,其中的元素是函数 ``fn`` 参数的字符串名字 """ return list(inspect.signature(fn).parameters) @@ -54,24 +49,18 @@ def get_fn_arg_names(fn: Callable) -> List[str]: def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: r""" - 该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping - 参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 + 该函数会根据输入函数的形参名从 ``*args`` (因此都需要是 ``dict`` 类型)中找到匹配的值进行调用,如果传入的数据与 ``fn`` 的形参不匹配,可以通过 + ``mapping`` 参数进行转换。``mapping`` 参数中的一对 ``(key, value)`` 表示在 ``*args`` 中找到 ``key`` 对应的值,并将这个值传递给形参中名为 + ``value`` 的参数。 - 1.该函数用来提供给用户根据字符串匹配从而实现自动调用; - 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; - 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; - 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; - 4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同; - - :param fn: 用来进行实际计算的函数,其参数可以包含有默认值; - :param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数; - :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 - 参数值后,再传给 `fn` 进行实际的运算; - :param mapping: 一个字典,用来更改其前面的字典的键值; - - :return: 返回 `fn` 运行的结果; + 1. 该函数用来提供给用户根据字符串匹配从而实现自动调用; + 2. 注意 ``mapping`` 默认为 ``None``,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 ``mapping`` 为一个字典传入进来; + 如果 ``mapping`` 不为 ``None``,那么我们一定会先使用 ``mapping`` 将输入的字典的 ``keys`` 修改过来,因此请务必亲自检查 ``mapping`` 的正确性; + 3. 如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; + 4. 如果输入的函数是一个 ``partial`` 函数,情况同第三点,即和默认参数的情况相同; Examples:: + >>> # 1 >>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred); >>> batch = {"x": 20, "y": 1} @@ -84,6 +73,14 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None >>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70 >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 + + :param fn: 用来进行实际计算的函数,其参数可以包含有默认值; + :param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 ``fn`` 计算所需要的实际参数; + :param signature_fn: 函数,用来替换 ``fn`` 的函数签名,如果该参数不为 ``None``,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 + 参数值后,再传给 ``fn`` 进行实际的运算; + :param mapping: 一个字典,用来更改其前面的字典的键值; + + :return: 返回 ``fn`` 运行的结果; """ if signature_fn is not None: @@ -226,13 +223,13 @@ def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None): def check_user_specific_params(user_params: Dict, fn: Callable): """ - 该函数使用用户的输入来对指定函数的参数进行赋值; - 主要用于一些用户无法直接调用函数的情况; - 该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误; + 该函数使用用户的输入来对指定函数的参数进行赋值,主要用于一些用户无法直接调用函数的情况; + 该函数主要的作用在于帮助检查用户对使用函数 ``fn`` 的参数输入是否有误; - :param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值; - :param fn: 会被调用的函数; - :return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值; + :param user_params: 用户指定的参数的值,应当是一个字典,其中 ``key`` 表示每一个参数的名字, + ``value`` 为每一个参数的值; + :param fn: 将要被调用的函数; + :return: 返回一个字典,其中为在之后调用函数 ``fn`` 时真正会被传进去的参数的值; """ fn_arg_names = get_fn_arg_names(fn) @@ -243,6 +240,9 @@ def check_user_specific_params(user_params: Dict, fn: Callable): def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: + """ + 将传入的 `dataclass` 实例转换为字典。 + """ if not is_dataclass(data): raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") _dict = dict() @@ -253,21 +253,31 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict: def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any: r""" - 用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能; - 该函数应用于 `input_mapping` 和 `output_mapping`; - 对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用; - 对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; + 用来实现将输入的 ``batch``,或者输出的 ``outputs``,通过 ``mapping`` 将键值进行更换的功能; + 该函数应用于 ``input_mapping`` 和 ``output_mapping``; - 转换的逻辑按优先级依次为: + 对于 ``input_mapping``,该函数会在 :class:`~fastNLP.core.controllers.TrainBatchLoop` 中取完数据后立刻被调用; + 对于 ``output_mapping``,该函数会在 :class:`~fastNLP.core.Trainer` 的 :meth:`~fastNLP.core.Trainer.train_step` + 以及 :class:`~fastNLP.core.Evaluator` 的 :meth:`~fastNLP.core.Evaluator.train_step` 中得到结果后立刻被调用; - 1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; - 2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; - 如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; - 如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; - 如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 - mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 + 转换的逻辑按优先级依次为: - :param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 + 1. 如果 ``mapping`` 是一个函数,那么会直接返回 ``mapping(data)``; + 2. 如果 ``mapping`` 是一个 ``Dict``,那么 ``data`` 的类型只能为以下三种: ``[Dict, dataclass, Sequence]``; + + * 如果 ``data`` 是 ``Dict``,那么该函数会将 ``data`` 的 ``key`` 替换为 ``mapping[key]``; + * 如果 ``data`` 是 ``dataclass``,那么该函数会先使用 :func:`dataclasses.asdict` 函数将其转换为 ``Dict``,然后进行转换; + * 如果 ``data`` 是 ``Sequence``,那么该函数会先将其转换成一个对应的字典:: + + { + "_0": list[0], + "_1": list[1], + ... + } + + 然后使用 ``mapping`` 对这个 ``Dict`` 进行转换,如果没有匹配上 ``mapping`` 中的 ``key`` 则保持 ``\'\_number\'`` 这个形式。 + + :param mapping: 用于转换的字典或者函数;``mapping`` 是函数时,返回值必须为字典类型。 :param data: 需要被转换的对象; :return: 返回转换好的结果; """ @@ -320,21 +330,20 @@ def apply_to_collection( include_none: bool = True, **kwargs: Any, ) -> Any: - """将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。 - - this function credit to: https://github.com/PyTorchLightning/pytorch-lightning - Args: - data: the collection to apply the function to - dtype: the given function will be applied to all elements of this dtype - function: the function to apply - *args: positional arguments (will be forwarded to calls of ``function``) - wrong_dtype: the given function won't be applied if this type is specified and the given collections - is of the ``wrong_dtype`` even if it is of type ``dtype`` - include_none: Whether to include an element if the output of ``function`` is ``None``. - **kwargs: keyword arguments (will be forwarded to calls of ``function``) - - Returns: - The resulting collection + """ + 使用函数 ``function`` 递归地在 ``data`` 中的元素执行,但是仅在满足元素为 ``dtype`` 时执行。 + + 该函数参考了 `pytorch-lightning `_ 的实现 + + :param data: 需要进行处理的数据集合或数据 + :param dtype: 数据的类型,函数 ``function`` 只会被应用于 ``data`` 中类型为 ``dtype`` 的数据 + :param function: 对数据进行处理的函数 + :param args: ``function`` 所需要的其它参数 + :param wrong_dtype: ``function`` 一定不会生效的数据类型。如果数据既是 ``wrong_dtype`` 类型又是 ``dtype`` 类型 + 那么也不会生效。 + :param include_none: 是否包含执行结果为 ``None`` 的数据,默认为 ``True``。 + :param kwargs: ``function`` 所需要的其它参数 + :return: 经过 ``function`` 处理后的数据集合 """ # Breaking condition if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): @@ -402,16 +411,18 @@ def apply_to_collection( @contextmanager def nullcontext(): r""" - 用来实现一个什么 dummy 的 context 上下文环境; + 实现一个什么都不做的上下文环境 """ yield def sub_column(string: str, c: int, c_size: int, title: str) -> str: r""" + 对传入的字符串进行截断,方便在命令行中显示 + :param string: 要被截断的字符串 :param c: 命令行列数 - :param c_size: instance或dataset field数 + :param c_size: :class:`~fastNLP.core.Instance` 或 :class:`fastNLP.core.DataSet` 的 ``field`` 数目 :param title: 列名 :return: 对一个过长的列进行截断的结果 """ @@ -442,18 +453,17 @@ def _is_iterable(value): def pretty_table_printer(dataset_or_ins) -> PrettyTable: r""" - :param dataset_or_ins: 传入一个dataSet或者instance - - .. code-block:: + 在 ``fastNLP`` 中展示数据的函数:: - ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) + >>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) +-----------+-----------+-----------------+ | field_1 | field_2 | field_3 | +-----------+-----------+-----------------+ | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | +-----------+-----------+-----------------+ - :return: 以 pretty table的形式返回根据terminal大小进行自动截断 + :param dataset_or_ins: 要展示的 :class:`~fastNLP.core.DataSet` 或者 :class:`~fastNLP.core.Instance` + :return: 根据 ``terminal`` 大小进行自动截断的数据表格 """ x = PrettyTable() try: @@ -486,7 +496,7 @@ def pretty_table_printer(dataset_or_ins) -> PrettyTable: class Option(dict): - r"""a dict can treat keys as attributes""" + r"""将键转化为属性的字典类型""" def __getattr__(self, item): try: @@ -516,11 +526,10 @@ _emitted_deprecation_warnings = set() def deprecated(help_message: Optional[str] = None): - """Decorator to mark a function as deprecated. + """ + 标记当前功能已经过时的装饰器。 - Args: - help_message (`Optional[str]`): An optional message to guide the user on how to - switch to non-deprecated usage of the library. + :param help_message: 一段指引信息,告知用户如何将代码切换为当前版本提倡的用法。 """ def decorator(deprecated_function: Callable): @@ -549,11 +558,10 @@ def deprecated(help_message: Optional[str] = None): return decorator -def seq_len_to_mask(seq_len, max_len=None): +def seq_len_to_mask(seq_len, max_len: Optional[int]): r""" - 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 - 转变 1-d seq_len到2-d mask. + 将一个表示 ``sequence length`` 的一维数组转换为二维的 ``mask`` ,不包含的位置为 **0**。 .. code-block:: @@ -570,10 +578,11 @@ def seq_len_to_mask(seq_len, max_len=None): >>>print(mask.size()) torch.Size([14, 100]) - :param np.ndarray,torch.LongTensor seq_len: shape将是(B,) - :param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 - 区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 - :return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8 + :param seq_len: 大小为是 ``(B,)`` 的长度序列 + :param int max_len: 将长度 ``pad`` 到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度。 + 但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入 + 一个 ``max_len`` 使得 ``mask`` 的长度 ``pad`` 到该长度。 + :return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8`` """ if isinstance(seq_len, np.ndarray): assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." diff --git a/fastNLP/envs/utils.py b/fastNLP/envs/utils.py index 355c2448..3936771e 100644 --- a/fastNLP/envs/utils.py +++ b/fastNLP/envs/utils.py @@ -6,6 +6,7 @@ from packaging.version import Version import subprocess import pkg_resources +__all__ = [] def _module_available(module_path: str) -> bool: """Check if a path is available in your environment. @@ -48,10 +49,11 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version: pkg_version = Version(pkg_version.base_version) return op(pkg_version, Version(version)) -def get_gpu_count(): +def get_gpu_count() -> int: """ - 利用命令行获取gpu数目的函数 - :return: gpu数目,如果没有显卡设备则为-1 + 利用命令行获取 ``gpu`` 数目的函数 + + :return: 显卡数目,如果没有显卡设备则为-1 """ try: lines = subprocess.check_output(['nvidia-smi', '--query-gpu=memory.used', '--format=csv'])