From 27cddb77be949d411679bc98de08f448b7dd68af Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 10 May 2022 21:32:08 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8A=A0=E4=BA=86=20Trainer=20=E7=9A=84?= =?UTF-8?q?=E4=B8=80=E4=BA=9B=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/.DS_Store | Bin 0 -> 6148 bytes fastNLP/core/controllers/trainer.py | 191 +++++++++++++++++++++------- 2 files changed, 142 insertions(+), 49 deletions(-) create mode 100644 fastNLP/core/.DS_Store diff --git a/fastNLP/core/.DS_Store b/fastNLP/core/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..2a1e21a54e6b4dd874e5da36ad493733ed4d458f GIT binary patch literal 6148 zcmeHKJ8Hu~5S>X>7~H5#xmWNF7UP`27ce-+jgUZ2Qme|jayxO`7mV%)H&v zycK$dMk6A+|9IPpv?8*B8_Ji3rP+P+g*|0PfpDBL${-!C3s&pQxLY}c*US5ifBQU+ zeK+*iIDOr5tfB%`fC^9nDnJDuC}6!8w%G@-0V?pX6wu~yJnZpG*;^McXT7$-PjIXGhMQsS6a;U_KySy` gSUZ04qNppj#{HT&1v(vhrvv#jV7kz#z+Wry13XO@IsgCw literal 0 HcmV?d00001 diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 9c9a859a..c9851e70 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -483,18 +483,62 @@ class Trainer(TrainerEventTrigger): def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, - catch_KeyboardInterrupt=None): + catch_KeyboardInterrupt = None): r""" - 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint + 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数; + + 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback`` 去保存断点重训的文件; - :param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 - :param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止,-1 为根据 dataloader 有多少个 batch 决定。 - :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误。为 0 表示不检测。 - :param resume_from: 从哪个路径下恢复 trainer 的状态 - :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 - :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 - 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) - :return: + + :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度; + :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度; + :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测; + :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹; + :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``, + 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的 + 其余状态都是保持初始化时的状态不会改变; + :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序, + ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver`` + 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True; + + .. warning:: + + 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时 + ``trainer.cur_epoch_idx == trainer.n_epochs``; + + 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``; + + .. note:: + + 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后, + 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度,在验证整体的训练流程没有错误后,再将 + 该值设定为 **-1** 开始真正的训练; + + ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证 + 整体的验证流程是否正确; + + ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用; + 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的 + 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator`` + 进行验证时会验证的 batch 的数量。 + + 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch`` + 应当为一个很小的正整数,例如 2; + + .. note:: + + 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效; + + 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch + 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练; + + fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分: + + 1. 您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`; + ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成; + 2. 在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``; + + 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。 """ if catch_KeyboardInterrupt is None: @@ -569,7 +613,12 @@ class Trainer(TrainerEventTrigger): finally: self.on_train_end() - def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): + def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl: int): + r""" + 用于设定训练过程中 ``Evaluator`` 进行验证时所实际验证的 batch 的数量; + + :param num_eval_batch_per_dl: 等价于 :meth:`~fastNLP.core.controllers.Trainer.run` 中的参数 ``num_eval_batch_per_dl``; + """ def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: trainer.on_evaluate_begin() _evaluate_res: dict = evaluate_fn() @@ -579,10 +628,8 @@ class Trainer(TrainerEventTrigger): self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) def step_evaluate(self): - """ - 在每个 batch 结束后调用,根据设置执行 evaluate 。 - - :return: + r""" + 在训练过程中的每个 batch 结束后被调用,注意实际的 ``Evaluator.run`` 函数是否在此时被调用取决于用户设置的 **"验证频率"**; """ if self.evaluator is not None: if callable(self.evaluate_every): @@ -592,10 +639,8 @@ class Trainer(TrainerEventTrigger): self.run_evaluate() def epoch_evaluate(self): - """ - 在每个 epoch 结束后调用,根据设置执行 evaluate 。 - - :return: + r""" + 在训练过程中的每个 epoch 结束后被调用,注意实际的 ``Evaluator.run`` 函数是否在此时被调用取决于用户设置的 **"验证频率"**; """ if self.evaluator is not None: if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: @@ -605,11 +650,52 @@ class Trainer(TrainerEventTrigger): def add_callback_fn(self, event: Event, fn: Callable): r""" - 在初始化一个 trainer 实例后,用户可以使用这一函数来方便地添加 callback 函数; - 这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; + 在初始化一个 trainer 实例后,您可以使用这一函数来方便地添加 ``callback`` 函数; + + 注意这一函数应当交给具体的 trainer 实例去做,因此不需要 `mark` 参数; - :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机; + :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机;具体有哪些时机详见 :class:`fastNLP.core.callbacks.Event`; :param fn: 具体的 callback 函数; + + .. note:: + + 对于训练一个神经网络的整体的流程来说,其可以分为很多个时间点,例如 **"整体的训练前"**,**"训练具体的一个 epoch 前"**, + **"反向传播前"**,**"整体的训练结束后"**等;一个 ``callback`` 时机指的就是这些一个个具体的时间点; + + 该函数的参数 ``event`` 需要是一个 ``Event`` 实例,其使用方式见下方的例子; + + 一个十分需要注意的事情在于您需要保证您添加的 callback 函数 ``fn`` 的参数与对应的 callback 时机所需要的参数保持一致,更准确地说, + 是与 :class:`fastNLP.core.callbacks.Callback` 中的对应的 callback 函数的参数保持一致;例如如果 + 您想要在 ``on_after_trainer_initialized`` 这个时机添加一个您自己的 callback 函数,您需要保证其参数为 ``trainer, driver``; + + 最后用一句话总结:对于您想要加入的一个 callback 函数,您首先需要确定您想要将该函数加入的 callback 时机,然后通过 ``Event.on_***()`` + 拿到具体的 event 实例;再去 :class:`fastNLP.core.callbacks.Callback` 中确定该 callback 时机的 callback 函数的参数应当是怎样的; + + 例如: + + .. code-block:: + + from fastNLP import Trainer, Event + + # Trainer 初始化 + trainer = Trainer(...) + + # 定义您自己的 callback 函数,需要注意的是该函数的参数需要与您要添加的 callback 时机所需要的参数保持一致;因为我们要将该函数加入到 + # on_after_trainer_initialized 这个 callback 时机,因此我们这里的 + def my_callback_fn(trainer, driver): + # do something + # 您可以在函数内部使用 trainer 和 driver,我们会将这两个实例注入进去; + + # 添加到 trainer 中; + trainer.add_callback_fn(Event.on_after_trainer_initialized(), my_callback_fn) + + .. note:: + + 该函数与 ``Trainer.on`` 函数提供的作用相同,它们所需要的参数也基本相同,区别在于 ``Trainer.on`` 用于 ``Trainer`` 初始化前,而 + ``Trainer.add_callback_fn`` 用于 ``Trainer`` 初始化之后; + + 更为具体的解释见 :meth:`~fastNLP.core.controllers.Trainer.on`; + """ if not isinstance(event, Event): raise ValueError("parameter event should only be `Event` type.") @@ -621,6 +707,7 @@ class Trainer(TrainerEventTrigger): def on(cls, event: Event, marker: Optional[str] = None): r""" 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; + 支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如:: Trainer.__init__(): @@ -655,7 +742,15 @@ class Trainer(TrainerEventTrigger): on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 特定的时间调用。 - Example:: + .. note:: + + 对于 event 的解释,建议先阅读 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn` 的文档; + + 当生成一个具体的 ``Event`` 实例时,可以指定 ``every、once、filter_fn`` 这三个参数来控制您的 callback 函数的调用频率,例如当您 + 指定 ``Event.on_train_epoch_begin(every=3)`` 时,其表示每隔三个 epoch 运行一次您的 callback 函数;对于这三个参数的更具体的解释, + 请见 :class:`fastNLP.core.callbacks.Event`; + + Example1:: from fastNLP import Event @Trainer.on(Event.on_save_model()) @@ -673,42 +768,40 @@ class Trainer(TrainerEventTrigger): # do something # 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。 - .. note:: - - - 例如: - - .. code-block:: + Example2:: - @Trainer.on(Event.on_train_begin()) - def fn1(trainer): - ... + @Trainer.on(Event.on_train_begin()) + def fn1(trainer): + ... - @Trainer.on(Event.on_train_epoch_begin()) - def fn2(trainer): - ... + @Trainer.on(Event.on_train_epoch_begin()) + def fn2(trainer): + ... - trainer1 = Trainer( - ..., - marker='trainer1' - ) + trainer1 = Trainer( + ..., + marker='trainer1' + ) - @Trainer.on(Event.on_fetch_data_begin()) - def fn3(trainer): - ... + @Trainer.on(Event.on_fetch_data_begin()) + def fn3(trainer): + ... - trainer2 = Trainer( - ..., - marker='trainer2' - ) + trainer2 = Trainer( + ..., + marker='trainer2' + ) + 这段代码意味着 ``fn1`` 和 ``fn2`` 会被加入到 ``trainer1``,``fn3`` 会被加入到 ``trainer2``; 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; + 补充性的解释见 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn`; + :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含 特定的参数,可以通过上述说明查阅。 - :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 `marker` 为 None(默认情况)时, - 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 `marker` 为 'all' 时,该 callback 函数会被所有的 trainer + :param marker: 用来标记该 callback 函数属于哪几个具体的 trainer 实例;两个特殊情况:1.当 ``marker`` 为 None(默认情况)时, + 表示该 callback 函数只属于代码下方最近的一个 trainer 实例;2.当 ``marker`` 为 'all' 时,该 callback 函数会被所有的 trainer 实例使用; :return: 返回原函数; """ @@ -722,7 +815,7 @@ class Trainer(TrainerEventTrigger): return wrapper def _fetch_matched_fn_callbacks(self): - """ + r""" 因为对于使用装饰器加入的函数 callback,我们是加在类属性中,因此在初始化一个具体的 trainer 实例后,我们需要从 Trainer 的 callback 类属性中将属于其的 callback 函数拿到,然后加入到 callback_manager 中; """