Browse Source

controllers.utils 的文档基本完成

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
8de1477aa1
2 changed files with 25 additions and 24 deletions
  1. +2
    -14
      fastNLP/core/controllers/utils/state.py
  2. +23
    -10
      fastNLP/core/controllers/utils/utils.py

+ 2
- 14
fastNLP/core/controllers/utils/state.py View File

@@ -1,19 +1,6 @@
"""

该 Module 用来实现一个用于记载用户 callback 实时数据的 state,该 state 实际上是一个 字典,我们通过复用 __getattr__ 方法来实现类似
类属性的字典调用方式;

提供该类的主要目的在于与 Filter 中的特殊的 filter_fn 合作,方便用户能够使用到自己想要的一切特殊的定制方式;

这一特殊的 Filter 实现需要用户记录一些特殊的状态值,例如 accuracy 等,但是我们不希望用户将这些状态值直接挂在 trainer 实例上,因为这样会
污染 trainer 自己的类属性,从而可能导致一些莫名其妙的 bug;

我们开放 state 用于用户这一特殊的定制选择;
"""
from dataclasses import dataclass
from typing import Optional, Dict


__all__ = [
'State',
'TrainerState'
@@ -22,7 +9,8 @@ __all__ = [

class State(dict):
r"""
提供给用户使用的 state;
提供给用户使用的 ``state``,用来记载您的 ``callback`` 实时数据,该 ``state`` 实际上是一个字典,我们通过复用 ``__getattr__`` 方法来实现类似
类属性的字典调用方式;

为了实现断点重训,用户应当保证其保存的信息都是可序列化的;



+ 23
- 10
fastNLP/core/controllers/utils/utils.py View File

@@ -1,4 +1,3 @@
import inspect
from typing import Dict

from fastNLP.core.callbacks import CallbackManager
@@ -7,10 +6,10 @@ from fastNLP.core.utils.utils import _check_valid_parameters_number


class TrainerEventTrigger:
"""
r"""
为了避免在训练流程中调用 callback 函数中写成类似 'trainer.callback_manager.on_train_begin' 的形式,我们选择单独抽象为 'Trainer'
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存
模型。
抽象一层,然后一些特殊的操作可以在这里进行,例如我们通过 `on_validate_end` 来通知所有的 'CheckpointCallback' 实例在当前的 step 后保存
模型。
"""
callback_manager: CallbackManager
trainer_state: TrainerState
@@ -90,13 +89,21 @@ class TrainerEventTrigger:


class _TruncatedDataLoader:
r"""
``_TruncatedDataLoader`` 用于实现 ``Trainer`` 和 ``Evaluator`` 中的 '预跑' 和 '假跑' 功能:

1. 预跑 是针对 trainer 的验证而言的,即我们在正式的训练前会先使用 trainer 内置的 evaluator(如果不为 None)评测数量非常少的数据,
来检验用户的 metric 和 evaluate_dataloader 以及模型是否能够合作完成正确的评测过程;
2. 假跑 的意思是改变每一个 epoch 中训练或者评测的实际的 batch 的数量,例如改成 10,来让模型快速地迭代整体的训练或者评测流程,来查看
整体的过程的正确性;

``_TruncatedDataLoader`` 的实现也非常简单,我们在该类中内置一个计数器,当迭代器的迭代数量达到指定数值后 ``raise StopIteration``;

:param dataloader: 可迭代的 dataloader 。
:param num_batches: 迭代多少个 batch 就停止。
"""
def __init__(self, dataloader, num_batches: int):
"""
限制

:param dataloader: 可迭代的 dataloader 。
:param num_batches: 迭代多少个 batch 就停止。
"""
self.dataloader = dataloader
self._num_batches = min(num_batches, len(dataloader))
self._count = 0
@@ -104,7 +111,6 @@ class _TruncatedDataLoader:
def __len__(self):
r"""
为了在外部调用 `len` 方法时正确地返回当前会迭代的长度;

"""
return self._num_batches

@@ -127,6 +133,13 @@ class _TruncatedDataLoader:


def check_evaluate_every(evaluate_every):
r"""
检验用户传入的 ``evaluate_every`` 参数是否合法;

``evaluate_every`` 的使用详见 ``Trainer`` 的 ``evaluate_every`` 参数;

主要在于当参数 ``evaluate_every`` 是一个 callable 的函数时,需要保证其参数的正确性;
"""
if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0):
raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.")
if callable(evaluate_every):


Loading…
Cancel
Save