@@ -10,16 +10,21 @@ from fastNLP.core.utils import match_and_substitute_params | |||||
class EvaluateBatchLoop(Loop): | class EvaluateBatchLoop(Loop): | ||||
r""" | |||||
``EvaluateBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的评测迭代过程; | |||||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||||
""" | |||||
def __init__(self, batch_step_fn:Optional[Callable]=None): | def __init__(self, batch_step_fn:Optional[Callable]=None): | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
def run(self, evaluator, dataloader) -> Dict: | def run(self, evaluator, dataloader) -> Dict: | ||||
""" | |||||
r""" | |||||
需要返回在传入的 dataloader 中的 evaluation 结果 | 需要返回在传入的 dataloader 中的 evaluation 结果 | ||||
:param evaluator: Evaluator 对象 | :param evaluator: Evaluator 对象 | ||||
:param dataloader: 当前需要进行 evaluate 的dataloader | |||||
:param dataloader: 当前需要进行评测的dataloader | |||||
:return: | :return: | ||||
""" | """ | ||||
iterator = iter(dataloader) | iterator = iter(dataloader) | ||||
@@ -48,5 +53,11 @@ class EvaluateBatchLoop(Loop): | |||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(evaluator, batch): | def batch_step_fn(evaluator, batch): | ||||
r""" | |||||
针对一个 batch 的数据的评测过程; | |||||
:param evaluator: Evaluator 对象 | |||||
:param batch: 当前需要评测的一个 batch 的数据; | |||||
""" | |||||
outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | outputs = evaluator.evaluate_step(batch) # 将batch输入到model中得到结果 | ||||
evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 | evaluator.update(batch, outputs) # evaluator将根据metric的形参名字从batch/outputs中取出对应的值进行赋值 |
@@ -1,3 +1,12 @@ | |||||
r""" | |||||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,为了在实现 fastNLP 主要功能的同时保证 fastNLP 的易用性和代码的易读性,我们只对 | |||||
训练中的循环做了非常简单的抽象,``Loop`` 表示的是在训练或者评测的过程中针对单独一个 ``dataloader`` 的一个 ``epoch`` 的运算过程; | |||||
更为具体的使用详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` 和 | |||||
:class:`~fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop` ; | |||||
""" | |||||
from typing import Union | |||||
__all__ = [ | __all__ = [ | ||||
'Loop' | 'Loop' | ||||
@@ -5,13 +14,25 @@ __all__ = [ | |||||
class Loop: | class Loop: | ||||
r""" | |||||
``TrainBatchLoop`` 和 ``EvaluateBatchLoop`` 的父类,您可以继承此类来定制自己的训练或者评测 ``loop``; | |||||
""" | |||||
def run(self, *args, **kwargs): | |||||
""" | |||||
该循环的主要运行过程; | |||||
""" | |||||
def run(self, controller: Union["Trainer", "Evaluator"], dataloader): | |||||
r""" | |||||
遍历参数 ``dataloader`` 的所有数据,使用 ``controller`` 进行训练或者评测; | |||||
.. note:: | |||||
``Trainer`` 和 ``Evaluator`` 中都提供了方便您进行定制 ``Loop`` 的接口函数,例如 ``Trainer.train_step``,``Trainer.backward``等; | |||||
在定制您自己的 ``TrainBatchLoop`` 时,请务必记得在正确的时机调用对应的 callback 函数,详见 :class:`~fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop` | |||||
中对于 callback 函数的调用; | |||||
def step(self, *args, **kwargs): | |||||
""" | """ | ||||
该循环运行过程中的一步; | |||||
@staticmethod | |||||
def batch_step_fn(controller: Union["Trainer", "Evaluator"], batch): | |||||
r""" | |||||
对于具体的一个 batch 的数据,实现训练或者评测过程中的一步; | |||||
""" | """ |
@@ -11,11 +11,27 @@ from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class TrainBatchLoop(Loop): | class TrainBatchLoop(Loop): | ||||
r""" | |||||
``TrainBatchLoop`` 针对一个 dataloader 的数据完成一个 epoch 的训练迭代过程; | |||||
:param batch_step_fn: 您可以传入该参数来替换默认的 bath_step_fn; | |||||
""" | |||||
def __init__(self, batch_step_fn: Optional[Callable] = None): | def __init__(self, batch_step_fn: Optional[Callable] = None): | ||||
if batch_step_fn is not None: | if batch_step_fn is not None: | ||||
self.batch_step_fn = batch_step_fn | self.batch_step_fn = batch_step_fn | ||||
def run(self, trainer, dataloader): | def run(self, trainer, dataloader): | ||||
r""" | |||||
对传入的 dataloader 进行一个 epoch 的主要的训练的循环过程; | |||||
.. note:: | |||||
您不需要自己主动地调用该方法,``Trainer`` 会负责调用该方法来完成训练过程; | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param dataloader: 当前训练所使用的 dataloader; | |||||
""" | |||||
get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\ | ||||
else lambda *args, **kwargs: None | else lambda *args, **kwargs: None | ||||
dataloader = iter(dataloader) | dataloader = iter(dataloader) | ||||
@@ -49,6 +65,12 @@ class TrainBatchLoop(Loop): | |||||
@staticmethod | @staticmethod | ||||
def batch_step_fn(trainer, batch): | def batch_step_fn(trainer, batch): | ||||
r""" | |||||
针对一个 batch 的数据的训练过程; | |||||
:param trainer: ``Trainer`` 实例; | |||||
:param batch: 一个 batch 的数据; | |||||
""" | |||||
outputs = trainer.train_step(batch) | outputs = trainer.train_step(batch) | ||||
trainer.backward(outputs) | trainer.backward(outputs) | ||||
trainer.step() | trainer.step() | ||||