| @@ -10,16 +10,21 @@ from fastNLP.core.utils import match_and_substitute_params | |||
| class EvaluateBatchLoop(Loop): | |||
| r""" | |||
| ``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | |||
| :param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||
| """ | |||
| def __init__(self, batch_step_fn:Optional[Callable]=None): | |||
| if batch_step_fn is not None: | |||
| self.batch_step_fn = batch_step_fn | |||
| def run(self, evaluator, dataloader) -> Dict: | |||
| """ | |||
| r""" | |||
| 需要返回在传入的 dataloader 中的 evaluation 结果 | |||
| :param evaluator: Evaluator 对象 | |||
| :param dataloader: 当前需要进行 evaluate 的dataloader | |||
| :param dataloader: 当前需要进行评测的dataloader | |||
| :return: | |||
| """ | |||
| iterator = iter(dataloader) | |||
| @@ -48,5 +53,11 @@ class EvaluateBatchLoop(Loop): | |||
| @staticmethod | |||
| def batch_step_fn(evaluator, batch): | |||
| r""" | |||
| 针对一个 batch 的数据的评测过程; | |||
| :param evaluator: Evaluator 对象 | |||
| :param batch: 当前需要评测的一个 batch 的数据; | |||
| """ | |||
| outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | |||
| evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 | |||
| @@ -1,3 +1,12 @@ | |||
| r""" | |||
| ``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 | |||
| 训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | |||
| 更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | |||
| :class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ; | |||
| """ | |||
| from typing import Union | |||
| __all__ = [ | |||
| 'Loop' | |||
| @@ -5,13 +14,25 @@ __all__ = [ | |||
| class Loop: | |||
| r""" | |||
| ``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``; | |||
| """ | |||
| def run(self, *args, **kwargs): | |||
| """ | |||
| 该循环的主要运行过程; | |||
| """ | |||
| def run(self, controller: Union["Trainer", "Evaluator"], dataloader): | |||
| r""" | |||
| 遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测; | |||
| .. note:: | |||
| ``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; | |||
| 在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | |||
| 中对于 callback 函数的调用; | |||
| def step(self, *args, **kwargs): | |||
| """ | |||
| 该循环运行过程中的一步; | |||
| @staticmethod | |||
| def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | |||
| r""" | |||
| 对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; | |||
| """ | |||
| @@ -11,11 +11,27 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||
| class TrainBatchLoop(Loop): | |||
| r""" | |||
| ``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | |||
| :param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||
| """ | |||
| def __init__(self, batch_step_fn: Optional[Callable] = None): | |||
| if batch_step_fn is not None: | |||
| self.batch_step_fn = batch_step_fn | |||
| def run(self, trainer, dataloader): | |||
| r""" | |||
| 对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; | |||
| .. note:: | |||
| 您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | |||
| :param trainer: ``Trainer`` 实例; | |||
| :param dataloader: 当前训练所使用的 dataloader; | |||
| """ | |||
| get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | |||
| else lambda *args, **kwargs: None | |||
| dataloader = iter(dataloader) | |||
| @@ -49,6 +65,12 @@ class TrainBatchLoop(Loop): | |||
| @staticmethod | |||
| def batch_step_fn(trainer, batch): | |||
| r""" | |||
| 针对一个 batch 的数据的训练过程; | |||
| :param trainer: ``Trainer`` 实例; | |||
| :param batch: 一个 batch 的数据; | |||
| """ | |||
| outputs = trainer.train_step(batch) | |||
| trainer.backward(outputs) | |||
| trainer.step() | |||