@@ -10,16 +10,21 @@ from fastNLP.core.utils import match_and_substitute_params | |||
class EvaluateBatchLoop(Loop): | |||
r""" | |||
``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | |||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||
""" | |||
def __init__(self, batch_step_fn:Optional[Callable]=None): | |||
if batch_step_fn is not None: | |||
self.batch_step_fn = batch_step_fn | |||
def run(self, evaluator, dataloader) -> Dict: | |||
""" | |||
r""" | |||
需要返回在传入的 dataloader 中的 evaluation 结果 | |||
:param evaluator: Evaluator 对象 | |||
:param dataloader: 当前需要进行 evaluate 的dataloader | |||
:param dataloader: 当前需要进行评测的dataloader | |||
:return: | |||
""" | |||
iterator = iter(dataloader) | |||
@@ -48,5 +53,11 @@ class EvaluateBatchLoop(Loop): | |||
@staticmethod | |||
def batch_step_fn(evaluator, batch): | |||
r""" | |||
针对一个 batch 的数据的评测过程; | |||
:param evaluator: Evaluator 对象 | |||
:param batch: 当前需要评测的一个 batch 的数据; | |||
""" | |||
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | |||
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 |
@@ -1,3 +1,12 @@ | |||
r""" | |||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 | |||
训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | |||
更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | |||
:class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ; | |||
""" | |||
from typing import Union | |||
__all__ = [ | |||
'Loop' | |||
@@ -5,13 +14,25 @@ __all__ = [ | |||
class Loop: | |||
r""" | |||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``; | |||
""" | |||
def run(self, *args, **kwargs): | |||
""" | |||
该循环的主要运行过程; | |||
""" | |||
def run(self, controller: Union["Trainer", "Evaluator"], dataloader): | |||
r""" | |||
遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测; | |||
.. note:: | |||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; | |||
在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | |||
中对于 callback 函数的调用; | |||
def step(self, *args, **kwargs): | |||
""" | |||
该循环运行过程中的一步; | |||
@staticmethod | |||
def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | |||
r""" | |||
对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; | |||
""" |
@@ -11,11 +11,27 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||
class TrainBatchLoop(Loop): | |||
r""" | |||
``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | |||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||
""" | |||
def __init__(self, batch_step_fn: Optional[Callable] = None): | |||
if batch_step_fn is not None: | |||
self.batch_step_fn = batch_step_fn | |||
def run(self, trainer, dataloader): | |||
r""" | |||
对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; | |||
.. note:: | |||
您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | |||
:param trainer: ``Trainer`` 实例; | |||
:param dataloader: 当前训练所使用的 dataloader; | |||
""" | |||
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | |||
else lambda *args, **kwargs: None | |||
dataloader = iter(dataloader) | |||
@@ -49,6 +65,12 @@ class TrainBatchLoop(Loop): | |||
@staticmethod | |||
def batch_step_fn(trainer, batch): | |||
r""" | |||
针对一个 batch 的数据的训练过程; | |||
:param trainer: ``Trainer`` 实例; | |||
:param batch: 一个 batch 的数据; | |||
""" | |||
outputs = trainer.train_step(batch) | |||
trainer.backward(outputs) | |||
trainer.step() | |||