@@ -33,9 +33,16 @@ class CheckpointCallback(Callback): | |||||
则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | 则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 在该 folder 下不进行模型保存。默认情况下,本 checkpoint 只保存了 model | ||||
的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | 的状态;如还需保存 Trainer 的状态以断点重训的话,请使用 ``save_object='trainer'`` 。 | ||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | :param folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 | ||||
时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 | ||||
:param every_n_epochs: 多少个 epoch 保存一次。 | :param every_n_epochs: 多少个 epoch 保存一次。 | ||||
@@ -12,9 +12,16 @@ class EarlyStopCallback(HasMonitorCallback): | |||||
def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): | ||||
""" | """ | ||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param patience: 多少次 evaluate 不没有提升就停止。 | :param patience: 多少次 evaluate 不没有提升就停止。 | ||||
""" | """ | ||||
@@ -0,0 +1,54 @@ | |||||
from fastNLP import HasMonitorCallback | |||||
import fitlog | |||||
class FitlogCallback(HasMonitorCallback): | |||||
def __init__(self, monitor=None, larger_better: bool = True, log_exception:bool=True, log_loss_every:int=0): | |||||
""" | |||||
自动记录 ``evaluation`` 结果到 ``fitlog`` 中的 ``Callback`` 。会根据 ``monitor`` 记录最好的结果,以及每一次 ``evaluate`` 后的 | |||||
结果。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是越大越好。 | |||||
:param log_exception: 是否记录 ``exception`` 。 | |||||
:param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 | |||||
""" | |||||
super().__init__(monitor=monitor, larger_better=larger_better) | |||||
self.log_exception = log_exception | |||||
self.log_loss_every = log_loss_every | |||||
self.avg_loss = 0 | |||||
def on_evaluate_end(self, trainer, results): | |||||
results = self.itemize_results(results) | |||||
fitlog.add_metric(results, step=trainer.global_forward_batches, epoch=trainer.cur_epoch_idx) | |||||
if self.is_better_results(results, keep_if_better=True): | |||||
results['step'] = trainer.global_forward_batches | |||||
results['epoch'] = trainer.cur_epoch_idx | |||||
fitlog.add_best_metric(results) | |||||
def on_before_backward(self, trainer, outputs): | |||||
if self.log_loss_every > 0: | |||||
loss = trainer.extract_loss_from_outputs(outputs) | |||||
self.avg_loss += loss.item() | |||||
if trainer.global_forward_batches % self.log_loss_every == 0: | |||||
fitlog.add_loss(self.avg_loss / self.log_loss_every * trainer.accumulation_steps, name='loss', | |||||
step=trainer.global_forward_batches, | |||||
epoch=trainer.cur_epoch_idx) | |||||
self.avg_loss = 0 | |||||
def on_train_end(self, trainer): | |||||
fitlog.finish() | |||||
def on_exception(self, trainer, exception): | |||||
fitlog.finish(status=1) | |||||
if self.log_exception: | |||||
fitlog.add_other(repr(exception), name='except_info') |
@@ -171,9 +171,16 @@ class HasMonitorCallback(ResultsMonitor, Callback): | |||||
该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | 该 callback 不直接进行使用,作为其它相关 callback 的父类使用,如果 callback 有使用 monitor 可以继承该函数里面实现了 | ||||
(1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | (1)判断monitor合法性;(2)在需要时, 根据trainer的monitor设置自己的monitor名称。 | ||||
:param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: monitor 是否时越大越好 | :param larger_better: monitor 是否时越大越好 | ||||
:param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | :param must_have_monitor: 这个 callback 是否必须有 monitor 设置。如果设置为 True ,且没检测到设置 monitor 会报错。 | ||||
""" | """ | ||||
@@ -22,9 +22,16 @@ class LoadBestModelCallback(HasMonitorCallback): | |||||
保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 | 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型,默认会在加载之后删除权重文件。仅在训练正常结束的时候才能加载 | ||||
最好的模型。 | 最好的模型。 | ||||
:param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最长公共字符串算法 找到最匹配 | |||||
的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 | |||||
果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有相关的 monitor 值请返回 None 。 | |||||
:param monitor: 监控的 metric 值。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 该 metric 值是否是越大越好。 | :param larger_better: 该 metric 值是否是越大越好。 | ||||
:param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 | ||||
不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 | ||||
@@ -45,10 +45,16 @@ class RichCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||||
完全一致的名称,将使用 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||||
相关的 monitor 值请返回 None 。 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | :param larger_better: 是否是 monitor 的结果越大越好。 | ||||
:param format_json: 是否格式化 json 再打印 | :param format_json: 是否格式化 json 再打印 | ||||
""" | """ | ||||
@@ -140,10 +146,16 @@ class RawTextCallback(ProgressCallback): | |||||
:param print_every: 多少个 batch 更新一次显示。 | :param print_every: 多少个 batch 更新一次显示。 | ||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | ||||
:param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 | |||||
完全一致的名称,将使用 最长公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor | |||||
。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果,如果当前结果中没有 | |||||
相关的 monitor 值请返回 None 。 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是monitor的结果越大越好。 | :param larger_better: 是否是monitor的结果越大越好。 | ||||
:param format_json: 是否format json再打印 | :param format_json: 是否format json再打印 | ||||
""" | """ | ||||
@@ -7,7 +7,7 @@ __all__ = [ | |||||
'Evaluator' | 'Evaluator' | ||||
] | ] | ||||
from fastNLP.core.drivers import Driver | |||||
from fastNLP.core.drivers import Driver, TorchDriver | |||||
from ..drivers.choose_driver import choose_driver | from ..drivers.choose_driver import choose_driver | ||||
from .loops import Loop, EvaluateBatchLoop | from .loops import Loop, EvaluateBatchLoop | ||||
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | ||||
@@ -316,15 +316,15 @@ class _MetricsWrapper: | |||||
raise TypeError("Parameter `metrics` can only be `Dict` type.") | raise TypeError("Parameter `metrics` can only be `Dict` type.") | ||||
for metric_name, metric in metrics.items(): | for metric_name, metric in metrics.items(): | ||||
# 因为 torchmetrics 是一个 nn.Module,因此我们需要先将其移到对应的机器上; | # 因为 torchmetrics 是一个 nn.Module,因此我们需要先将其移到对应的机器上; | ||||
if _is_torchmetrics_metric(metric): | |||||
if _is_torchmetrics_metric(metric) and isinstance(evaluator.driver, TorchDriver): | |||||
# torchmetrics 是默认自动开启了多卡的 | # torchmetrics 是默认自动开启了多卡的 | ||||
evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) | evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device) | ||||
elif isinstance(metric, Metric): | elif isinstance(metric, Metric): | ||||
# 如果数据是分布式的,但是不aggregate的话可能有问题 | # 如果数据是分布式的,但是不aggregate的话可能有问题 | ||||
if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False: | if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False: | ||||
logger.warning_once( | |||||
"You have replace the sampler as distributed sampler when evaluation, but your " | |||||
f"metric {metric_name}:{metric.__class__.__name__}' `aggregate_when_get_metric` is False.") | |||||
logger.rank_zero_warning( | |||||
"You have replace the sampler as distributed sampler when evaluation, but your metric " | |||||
f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True) | |||||
if metric.aggregate_when_get_metric is None: | if metric.aggregate_when_get_metric is None: | ||||
metric.aggregate_when_get_metric = evaluator._dist_sampler is not None | metric.aggregate_when_get_metric = evaluator._dist_sampler is not None | ||||
@@ -229,27 +229,19 @@ class Trainer(TrainerEventTrigger): | |||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 才让优化器迭代一次,默认为 1; | ||||
:param fp16: 是否开启混合精度训练,默认为 False; | :param fp16: 是否开启混合精度训练,默认为 False; | ||||
:param monitor: 对于一些特殊的 ``Callback``,例如 :class:`fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor`` | :param monitor: 对于一些特殊的 ``Callback``,例如 :class:`fastNLP.core.callbacks.CheckpointCallback`,它们需要参数 ``monitor`` | ||||
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。例如,对于 ``CheckpointCallback`` 而言,如果我们 | |||||
想要每隔一个 epoch 让 ``Evaluator`` 进行一次验证,然后保存训练以来的最好的结果;那么我们需要这样设置: | |||||
.. code-block:: | |||||
trainer = Trainer( | |||||
..., | |||||
metrics={'acc': accMetric()}, | |||||
callbacks=[CheckpointCallback( | |||||
..., | |||||
monitor='acc', | |||||
topk=1 | |||||
)] | |||||
) | |||||
这意味着对于 ``CheckpointCallback`` 来说,*'acc'* 就是一个监测的指标,用于在 ``Evaluator`` 验证后取出其需要监测的那个指标的值。 | |||||
``Trainer`` 中的参数 ``monitor`` 的作用在于为没有设置 ``monitor`` 参数但是需要该参数的 *callback* 实例设置该值。关于 ``monitor`` | |||||
参数更详细的说明,请见 :class:`fastNLP.core.callbacks.CheckpointCallback`; | |||||
注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | |||||
来从 ``Evaluator`` 的验证结果中获取当前评测的值,从而来判断是否执行一些特殊的操作。这里设置了 ``monitor`` 则所有的需要 | |||||
``monitor`` 但是没有自己设置的 ``Callback`` 都会使用这个值 | |||||
* 为 ``None`` | |||||
没有 monitor ,默认。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
注意该参数仅当传入了 ``evaluate_dataloaders`` 不为 ``None`` 时且有需要该参数但是没有设置该参数的 *Callback* 实例才有意义; | |||||
:param larger_better: 对于需要参数 ``monitor`` 的 *callback* 来说,``monitor`` 的值是否是越大越好;类似于 ``monitor``,其作用 | :param larger_better: 对于需要参数 ``monitor`` 的 *callback* 来说,``monitor`` 的值是否是越大越好;类似于 ``monitor``,其作用 | ||||
在于为没有设置 ``larger_better`` 参数但是需要该参数的 *callback* 实例设置该值; | 在于为没有设置 ``larger_better`` 参数但是需要该参数的 *callback* 实例设置该值; | ||||