Browse Source

添加了trainer 的文档1

tags/v1.0.0alpha
YWMditto 3 years ago
parent
commit
00a6c52440
1 changed files with 184 additions and 51 deletions
  1. +184
    -51
      fastNLP/core/controllers/trainer.py

+ 184
- 51
fastNLP/core/controllers/trainer.py View File

@@ -1,14 +1,10 @@
""" """
todo 写一下这里的开头文档
``Trainer`` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产 ``Trainer`` 是 fastNLP 用于训练模型的专门的训练器,其支持多种不同的驱动模式 ``Driver``,不仅包括最为经常使用的 DDP,而且还支持 jittor 等国产
的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现 的训练框架;新版的 fastNLP 新加入了方便的 callback 函数修饰器,并且支持定制用户自己特定的训练循环过程;通过使用该训练器,用户只需要自己实现
模型部分,而将训练层面的逻辑完全地交给 fastNLP; 模型部分,而将训练层面的逻辑完全地交给 fastNLP;



""" """


from typing import Union, Optional, List, Callable, Dict, Sequence, BinaryIO, IO
from typing import Union, Optional, List, Callable, Dict, BinaryIO
from functools import partial from functools import partial
from collections import defaultdict from collections import defaultdict
import copy import copy
@@ -71,8 +67,6 @@ class Trainer(TrainerEventTrigger):
**kwargs **kwargs
): ):
r""" r"""


:param model: 训练所需要的模型,例如 ``torch.nn.Module``; :param model: 训练所需要的模型,例如 ``torch.nn.Module``;


.. note:: .. note::
@@ -160,60 +154,167 @@ class Trainer(TrainerEventTrigger):
:param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor, :param metrics: 用于传给 ``Trainer`` 内部的 ``Evaluator`` 实例来进行训练过程中的验证。其应当为一个字典,其中 key 表示 monitor,
例如 {"acc1": AccMetric(), "acc2": AccMetric()}; 例如 {"acc1": AccMetric(), "acc2": AccMetric()};


目前我们支持的 ``metric`` 的种类有以下几种:

1. fastNLP 自己的 ``metric``:详见 :class:`fastNLP.core.metrics.Metric`;
2. torchmetrics;
3. allennlp.training.metrics;
4. paddle.metric;

:param evaluate_every: 用来控制 ``Trainer`` 内部的 ``Evaluator`` 验证的频率,其可以为负数、正数或者函数:

1. 为负数时表示每隔几个 ``epoch`` evaluate 一次;
2. 为正数则表示每隔几个 ``batch`` evaluate 一次;
3. 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 ``batch`` 结束后调用该函数判断是否需要 evaluate;

.. note::

如果参数 ``evaluate_every`` 为函数,其应当类似:

>>> def my_evaluate_every(trainer) -> bool:
... if (trainer.global_forward_batches+1) % 1000 == 0:
... return True
... else:
... return False

该函数表示当每经过 1000 个 batch,``Trainer`` 中内置的 ``Evaluator`` 就会验证一次;

另一个需要注意的事情在于该函数会在每一次 batch 的结尾进行调用,当该函数返回 ``True`` 时,``Evaluator`` 才会进行验证;

:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理:

1. 如果 ``input_mapping`` 是一个字典:

1. 如果此时 batch 也是一个 ``Dict``,那么我们会把 batch 中同样在 ``input_mapping`` 中的 key 修改为 ``input_mapping`` 的对应 ``key`` 的 ``value``;
2. 如果此时 batch 是一个 ``dataclass``,那么我们会先将其转换为一个 ``Dict``,然后再进行上述转换;
3. 如果此时 batch 此时是其它类型,那么我们将会直接报错;
2. 如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里;

注意该参数会被传进 ``Evaluator`` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时);
如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``input_mapping``, 请使用 ``train_input_mapping`` 与 ``evaluate_input_mapping`` 分别进行设置。

:param output_mapping: 应当为一个字典或者函数。作用和 ``input_mapping`` 类似,区别在于其用于转换输出:

1. 如果 ``output_mapping`` 是一个 ``Dict``,那么我们需要模型的输出必须是 ``Dict`` 或者 ``dataclass`` 类型:

1. 如果此时模型的输出是一个 ``Dict``,那么我们会把输出中同样在 ``output_mapping`` 中的 key 修改为 ``output_mapping`` 的对应 key 的 value;
2. 如果此时模型的输出是一个 ``dataclass``,那么我们会先将其转换为一个 Dict,然后再进行上述转换;
2. 如果 ``output_mapping`` 是一个函数,那么我们将会直接将模型的输出传给该函数;

如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``output_mapping``, 请使用 ``train_output_mapping`` 与 ``evaluate_output_mapping`` 分别进行设置;

.. note::

``input_mapping`` 和 ``output_mapping`` 与 fastNLP 的一个特殊的概念 **'参数绑定'** 高度相关,它们的存在也是为了 fastNLP
中的参数匹配能够正确地运行;

.. todo::
之后链接上 参数匹配 的文档;

.. warning::

如果 ``Trainer`` 的参数 ``output_mapping`` 不为 ``None``,请保证其返回的一定是一个字典,并且其中含有关键字 **'loss'**;

:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 ``auto_param_call`` 函数来自动匹配 batch 和前向函数的参数的行为;

1. 如果该值为 ``False``,并且当 batch 为字典时,我们会根据**前向函数**所需要的参数从 batch 中提取对应的对象,然后传入到**前向函数**中;
2. 如果该值为 ``True``,那么我们会将 batch 直接透传给模型;

.. todo::
之后链接上 参数匹配 的文档;

函数 ``auto_param_call`` 详见 :func:`fastNLP.core.utils.auto_param_call`;

:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1;
:param fp16: 是否开启混合精度训练,默认为 False;
:param monitor: 对于一些特殊的 ``Callback``,例如 :class:`fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor``
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 ``CheckpointCallback`` 而言,如果我们
想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置:


.. code-block::


trainer = Trainer(
...,
metrics={'acc': accMetric()},
callbacks=[CheckpointCallback(
...,
monitor='acc',
topk=1
)]
)

这意味着对于 ``CheckpointCallback`` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。

``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor``
参数更详细的说明,请见 :class:`fastNLP.core.callbacks.CheckpointCallback`;

注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;

:param larger_better: 对于需要参数 ``monitor`` 的 *callback* 来说,``monitor`` 的值是否是越大越好;类似于 ``monitor``,其作用
在于为没有设置 ``larger_better`` 参数但是需要该参数的 *callback* 实例设置该值;

注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效;

:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None;

.. note::

marker 的使用场景主要在于如果一个脚本中含有多个 ``Trainer`` 实例,并且含有多个使用 ``Trainer.on`` 修饰的函数时,不同的函数属于
不同的 ``Trainer`` 实例;

此时,通过将修饰器 ``Trainer.on`` 的参数 ``marker`` 和 ``Trainer`` 的参数 ``marker`` 置为相同,就可以使得该函数只会在这一
``Trainer`` 实例中被调用;例如,

.. code-block::

@Trainer.on(Event.on_train_begin(), marker='trainer1')
def fn(trainer):
...

trainer = Trainer(
...,
marker='trainer1'
)

另一点需要说明的是,如果一个被 ``Trainer.on`` 修饰的函数,其修饰时没有指明 ``marker``,那么会将该函数传给代码位于其之后的
第一个 ``Trainer`` 实例,即使该 ``Trainer`` 实例的 marker 不为 None;这一点详见 :meth:`~fastNLP.core.controllers.Trainer.on`


:param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch evaluate 一次;为正数则表示每隔几个 batch evaluate 一次;
为函数时表示用户自己传入的用于控制 Trainer 中的 evaluate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并
返回一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。
:param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是
一个字典,并且 batch 也是一个 `Dict`,那么我们会把 batch 中同样在 input_mapping 中的 key 修改为 input_mapping 的对应 key 的
value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它
类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里;
注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时);
如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型,
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value;
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;
如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`;
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1;
:param fp16: 是否开启混合精度训练;默认为 False;
:param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配
的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。
:param larger_better: monitor 的值是否是越大越好。
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
:param kwargs: 一些其它的可能需要的参数,见下方的说明
:kwargs: :kwargs:
* *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数: * *torch_kwargs* -- 用于在指定 ``driver`` 为 'torch' 时设定具体 driver 实例的一些参数:
* ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入 * ddp_kwargs -- 用于在使用 ``TorchDDPDriver`` 时指定 ``DistributedDataParallel`` 初始化时的参数;例如传入
{'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等; {'find_unused_parameters': True} 来解决有参数不参与前向运算导致的报错等;
* set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; * set_grad_to_none -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None;
* torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; * torch_non_blocking -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking;
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上;
注意如果 model_device 为 None,那么 data_device 不会起作用;
* *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
* *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True;
* *data_device* -- 一个具体的 driver 实例中,有 ``model_device`` 和 ``data_device``,前者表示模型所在的设备,后者表示
当 ``model_device`` 为 None 时应当将数据迁移到哪个设备;

.. note::

注意您在绝大部分情况下不会用到该参数!

1. 当 driver 实例的 ``model_device`` 不为 None 时,该参数无效;
2. 对于 pytorch,仅当用户自己通过 ``python -m torch.distributed.launch`` 并且自己初始化 ``init_process_group`` 时,
driver 实例的 ``model_device`` 才会为 None;

* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;默认为 ``True``;
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: * *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。


注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``;
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Evaluator`` 中。与 output_mapping 互斥。


.. note:: .. note::

``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证;
``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入:


必须的参数:1. ``metrics``;2. ``evaluate_dataloaders``; 必须的参数:1. ``metrics``;2. ``evaluate_dataloaders``;
@@ -223,15 +324,17 @@ class Trainer(TrainerEventTrigger):


.. warning:: .. warning::


如果 ``Trainer`` 中的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为:
如果 ``Trainer`` 中内置的 ``Evaluator`` 实例不为 ``None``,那么需要注意 ``Trainer`` 中的一些参数是与 ``Evaluator`` 一致的,它们分别为:


1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的 1. ``Evaluator`` 在初始化时的 ``driver`` 参数是 ``Trainer`` 中已经实例化过的 driver;这一点使得一些参数对于 ``Trainer`` 内部的
``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等; ``Evaluator`` 没有用处,例如 ``device``,``torch_kwargs``,``data_device`` 和 ``output_from_new_proc`` 等;
2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的
2. ``input_mapping``,``output_mapping``,``model_wo_auto_param_call`` 和 ``fp16`` 是 ``Trainer`` 和其内部默认的
``Evaluator`` 是一致的; ``Evaluator`` 是一致的;


当然,对于某些参数,您可以通过修改
当然,对于 ``input_mapping`` 和 ``output_mapping``,您可以通过添加 ``kwargs`` 中的参数 ``evaluate_input_mapping`` 和
``evaluate_output_mapping`` 来单独为 ``Evaluator`` 进行更细致的订制。


另一方面,注意一些专门独属于 ``Evaluator`` 的参数仅当 ``Evaluator`` 不为 None 时才会生效。


""" """
self.model = model self.model = model
@@ -352,7 +455,7 @@ class Trainer(TrainerEventTrigger):
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None),
@@ -381,7 +484,7 @@ class Trainer(TrainerEventTrigger):
def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1,
num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True,
catch_KeyboardInterrupt=None): catch_KeyboardInterrupt=None):
"""
r"""
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint
去保存断点重训的文件; 去保存断点重训的文件;
:param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 :param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。
@@ -570,6 +673,36 @@ class Trainer(TrainerEventTrigger):
# do something # do something
# 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。 # 以上函数会在 Trainer 每个新的 batch 开始的时候执行,但是是两个 batch 才执行一次。


.. note::


例如:

.. code-block::

@Trainer.on(Event.on_train_begin())
def fn1(trainer):
...

@Trainer.on(Event.on_train_epoch_begin())
def fn2(trainer):
...

trainer1 = Trainer(
...,
marker='trainer1'
)

@Trainer.on(Event.on_fetch_data_begin())
def fn3(trainer):
...

trainer2 = Trainer(
...,
marker='trainer2'
)


注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前;


:param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含 :param event: 特定的 callback 时机,用户需要为该 callback 函数指定其属于哪一个 callback 时机。每个时机运行的函数应该包含


Loading…
Cancel
Save