Browse Source

Loop 的文档基本完成

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
6665fee7c5
3 changed files with 62 additions and 8 deletions
  1. +13
    -2
      fastNLP/core/controllers/loops/evaluate_batch_loop.py
  2. +27
    -6
      fastNLP/core/controllers/loops/loop.py
  3. +22
    -0
      fastNLP/core/controllers/loops/train_batch_loop.py

+ 13
- 2
fastNLP/core/controllers/loops/evaluate_batch_loop.py View File

@@ -10,16 +10,21 @@ from fastNLP.core.utils import match_and_substitute_params


class EvaluateBatchLoop(Loop):
r"""
``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程;

:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn;
"""
def __init__(self, batch_step_fn:Optional[Callable]=None):
if batch_step_fn is not None:
self.batch_step_fn = batch_step_fn

def run(self, evaluator, dataloader) -> Dict:
"""
r"""
需要返回在传入的 dataloader 中的 evaluation 结果

:param evaluator: Evaluator 对象
:param dataloader: 当前需要进行 evaluate 的dataloader
:param dataloader: 当前需要进行评测的dataloader
:return:
"""
iterator = iter(dataloader)
@@ -48,5 +53,11 @@ class EvaluateBatchLoop(Loop):

@staticmethod
def batch_step_fn(evaluator, batch):
r"""
针对一个 batch 的数据的评测过程;

:param evaluator: Evaluator 对象
:param batch: 当前需要评测的一个 batch 的数据;
"""
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值

+ 27
- 6
fastNLP/core/controllers/loops/loop.py View File

@@ -1,3 +1,12 @@
r"""
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对
训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程;

更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和
:class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ;
"""

from typing import Union

__all__ = [
'Loop'
@@ -5,13 +14,25 @@ __all__ = [


class Loop:
r"""
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``;
"""

def run(self, *args, **kwargs):
"""
该循环的主要运行过程;
"""
def run(self, controller: Union["Trainer", "Evaluator"], dataloader):
r"""
遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测;

.. note::

``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等;

在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop`
中对于 callback 函数的调用;

def step(self, *args, **kwargs):
"""
该循环运行过程中的一步;

@staticmethod
def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch):
r"""
对于具体的一个 batch 的数据,实现训练或者评测过程中的一步;
"""

+ 22
- 0
fastNLP/core/controllers/loops/train_batch_loop.py View File

@@ -11,11 +11,27 @@ from fastNLP.core.utils.exceptions import EarlyStopException


class TrainBatchLoop(Loop):
r"""
``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程;

:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn;
"""

def __init__(self, batch_step_fn: Optional[Callable] = None):
if batch_step_fn is not None:
self.batch_step_fn = batch_step_fn

def run(self, trainer, dataloader):
r"""
对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程;

.. note::

您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程;

:param trainer: ``Trainer`` 实例;
:param dataloader: 当前训练所使用的 dataloader;
"""
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\
else lambda *args, **kwargs: None
dataloader = iter(dataloader)
@@ -49,6 +65,12 @@ class TrainBatchLoop(Loop):

@staticmethod
def batch_step_fn(trainer, batch):
r"""
针对一个 batch 的数据的训练过程;

:param trainer: ``Trainer`` 实例;
:param batch: 一个 batch 的数据;
"""
outputs = trainer.train_step(batch)
trainer.backward(outputs)
trainer.step()


Loading…
Cancel
Save