From 6665fee7c591ac07f747f9d88a397fc8fd7ad475 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Thu, 12 May 2022 15:52:44 +0800 Subject: [PATCH] =?UTF-8?q?Loop=20=E7=9A=84=E6=96=87=E6=A1=A3=E5=9F=BA?= =?UTF-8?q?=E6=9C=AC=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../controllers/loops/evaluate_batch_loop.py | 15 +++++++-- fastNLP/core/controllers/loops/loop.py | 33 +++++++++++++++---- .../controllers/loops/train_batch_loop.py | 22 +++++++++++++ 3 files changed, 62 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/controllers/loops/evaluate_batch_loop.py b/fastNLP/core/controllers/loops/evaluate_batch_loop.py index c81379a1..c6301772 100644 --- a/fastNLP/core/controllers/loops/evaluate_batch_loop.py +++ b/fastNLP/core/controllers/loops/evaluate_batch_loop.py @@ -10,16 +10,21 @@ from fastNLP.core.utils import match_and_substitute_params class EvaluateBatchLoop(Loop): + r""" + ``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; + + :param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; + """ def __init__(self, batch_step_fn:Optional[Callable]=None): if batch_step_fn is not None: self.batch_step_fn = batch_step_fn def run(self, evaluator, dataloader) -> Dict: - """ + r""" 需要返回在传入的 dataloader 中的 evaluation 结果 :param evaluator: Evaluator 对象 - :param dataloader: 当前需要进行 evaluate 的dataloader + :param dataloader: 当前需要进行评测的dataloader :return: """ iterator = iter(dataloader) @@ -48,5 +53,11 @@ class EvaluateBatchLoop(Loop): @staticmethod def batch_step_fn(evaluator, batch): + r""" + 针对一个 batch 的数据的评测过程; + + :param evaluator: Evaluator 对象 + :param batch: 当前需要评测的一个 batch 的数据; + """ outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 diff --git a/fastNLP/core/controllers/loops/loop.py b/fastNLP/core/controllers/loops/loop.py index 19f5ccc6..b1952236 100644 --- a/fastNLP/core/controllers/loops/loop.py +++ b/fastNLP/core/controllers/loops/loop.py @@ -1,3 +1,12 @@ +r""" +``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 +训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; + +更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 +:class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ; +""" + +from typing import Union __all__ = [ 'Loop' @@ -5,13 +14,25 @@ __all__ = [ class Loop: + r""" + ``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``; + """ - def run(self, *args, **kwargs): - """ - 该循环的主要运行过程; - """ + def run(self, controller: Union["Trainer", "Evaluator"], dataloader): + r""" + 遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测; + + .. note:: + + ``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; + + 在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` + 中对于 callback 函数的调用; - def step(self, *args, **kwargs): """ - 该循环运行过程中的一步; + + @staticmethod + def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): + r""" + 对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; """ \ No newline at end of file diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index 7bb9b653..48485226 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -11,11 +11,27 @@ from fastNLP.core.utils.exceptions import EarlyStopException class TrainBatchLoop(Loop): + r""" + ``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; + + :param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; + """ + def __init__(self, batch_step_fn: Optional[Callable] = None): if batch_step_fn is not None: self.batch_step_fn = batch_step_fn def run(self, trainer, dataloader): + r""" + 对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; + + .. note:: + + 您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; + + :param trainer: ``Trainer`` 实例; + :param dataloader: 当前训练所使用的 dataloader; + """ get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ else lambda *args, **kwargs: None dataloader = iter(dataloader) @@ -49,6 +65,12 @@ class TrainBatchLoop(Loop): @staticmethod def batch_step_fn(trainer, batch): + r""" + 针对一个 batch 的数据的训练过程; + + :param trainer: ``Trainer`` 实例; + :param batch: 一个 batch 的数据; + """ outputs = trainer.train_step(batch) trainer.backward(outputs) trainer.step()