diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py index 528ab529..a8103c62 100644 --- a/fastNLP/core/controllers/utils/state.py +++ b/fastNLP/core/controllers/utils/state.py @@ -1,19 +1,6 @@ -""" - -该 Module 用来实现一个用于记载用户 callback 实时数据的 state,该 state 实际上是一个 字典,我们通过复用 __getattr__ 方法来实现类似 -类属性的字典调用方式; - -提供该类的主要目的在于与 Filter 中的特殊的 filter_fn 合作,方便用户能够使用到自己想要的一切特殊的定制方式; - -这一特殊的 Filter 实现需要用户记录一些特殊的状态值,例如 accuracy 等,但是我们不希望用户将这些状态值直接挂在 trainer 实例上,因为这样会 -污染 trainer 自己的类属性,从而可能导致一些莫名其妙的 bug; - -我们开放 state 用于用户这一特殊的定制选择; -""" from dataclasses import dataclass from typing import Optional, Dict - __all__ = [ 'State', 'TrainerState' @@ -22,7 +9,8 @@ __all__ = [ class State(dict): r""" - 提供给用户使用的 state; + 提供给用户使用的 ``state``,用来记载您的 ``callback`` 实时数据,该 ``state`` 实际上是一个字典,我们通过复用 ``__getattr__`` 方法来实现类似 + 类属性的字典调用方式; 为了实现断点重训,用户应当保证其保存的信息都是可序列化的; diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index a2b2d5ae..ef3cf98c 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -1,4 +1,3 @@ -import inspect from typing import Dict from fastNLP.core.callbacks import CallbackManager @@ -7,10 +6,10 @@ from fastNLP.core.utils.utils import _check_valid_parameters_number class TrainerEventTrigger: - """ + r""" 为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer' - 抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 - 模型。 + 抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存 + 模型。 """ callback_manager: CallbackManager trainer_state: TrainerState @@ -90,13 +89,21 @@ class TrainerEventTrigger: class _TruncatedDataLoader: + r""" + ``_TruncatedDataLoader`` 用于实现 ``Trainer`` 和 ``Evaluator`` 中的 '预跑' 和 '假跑' 功能: + + 1. 预跑 是针对 trainer 的验证而言的,即我们在正式的训练前会先使用 trainer 内置的 evaluator(如果不为 None)评测数量非常少的数据, + 来检验用户的 metric 和 evaluate_dataloader 以及模型是否能够合作完成正确的评测过程; + 2. 假跑 的意思是改变每一个 epoch 中训练或者评测的实际的 batch 的数量,例如改成 10,来让模型快速地迭代整体的训练或者评测流程,来查看 + 整体的过程的正确性; + + ``_TruncatedDataLoader`` 的实现也非常简单,我们在该类中内置一个计数器,当迭代器的迭代数量达到指定数值后 ``raise StopIteration``; + + :param dataloader: 可迭代的 dataloader 。 + :param num_batches: 迭代多少个 batch 就停止。 + """ def __init__(self, dataloader, num_batches: int): - """ - 限制 - :param dataloader: 可迭代的 dataloader 。 - :param num_batches: 迭代多少个 batch 就停止。 - """ self.dataloader = dataloader self._num_batches = min(num_batches, len(dataloader)) self._count = 0 @@ -104,7 +111,6 @@ class _TruncatedDataLoader: def __len__(self): r""" 为了在外部调用 `len` 方法时正确地返回当前会迭代的长度; - """ return self._num_batches @@ -127,6 +133,13 @@ class _TruncatedDataLoader: def check_evaluate_every(evaluate_every): + r""" + 检验用户传入的 ``evaluate_every`` 参数是否合法; + + ``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数; + + 主要在于当参数 ``evaluate_every`` 是一个 callable 的函数时,需要保证其参数的正确性; + """ if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") if callable(evaluate_every):