|
@@ -3,30 +3,31 @@ __all__ = [ |
|
|
] |
|
|
] |
|
|
from .has_monitor_callback import HasMonitorCallback |
|
|
from .has_monitor_callback import HasMonitorCallback |
|
|
from ...envs import _module_available |
|
|
from ...envs import _module_available |
|
|
|
|
|
from ...envs import get_global_rank |
|
|
if _module_available('fitlog'): |
|
|
if _module_available('fitlog'): |
|
|
import fitlog |
|
|
import fitlog |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FitlogCallback(HasMonitorCallback): |
|
|
class FitlogCallback(HasMonitorCallback): |
|
|
|
|
|
""" |
|
|
|
|
|
自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据 |
|
|
|
|
|
``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。 |
|
|
|
|
|
|
|
|
|
|
|
:param monitor: 监控的 metric 值。 |
|
|
|
|
|
|
|
|
|
|
|
* 为 ``None`` |
|
|
|
|
|
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 |
|
|
|
|
|
* 为 ``str`` |
|
|
|
|
|
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 |
|
|
|
|
|
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 |
|
|
|
|
|
* 为 ``Callable`` |
|
|
|
|
|
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 |
|
|
|
|
|
的 ``monitor`` 值请返回 ``None`` 。 |
|
|
|
|
|
:param larger_better: 是否是越大越好。 |
|
|
|
|
|
:param log_exception: 是否记录 ``exception`` 。 |
|
|
|
|
|
:param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 |
|
|
|
|
|
""" |
|
|
def __init__(self, monitor=None, larger_better: bool = True, log_exception:bool=True, log_loss_every:int=0): |
|
|
def __init__(self, monitor=None, larger_better: bool = True, log_exception:bool=True, log_loss_every:int=0): |
|
|
""" |
|
|
|
|
|
自动记录 ``evaluation`` 结果到 ``fitlog`` 中的 ``Callback`` 。会根据 ``monitor`` 记录最好的结果,以及每一次 ``evaluate`` 后的 |
|
|
|
|
|
结果。 |
|
|
|
|
|
|
|
|
|
|
|
:param monitor: 监控的 metric 值。 |
|
|
|
|
|
|
|
|
|
|
|
* 为 ``None`` |
|
|
|
|
|
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 |
|
|
|
|
|
* 为 ``str`` |
|
|
|
|
|
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 |
|
|
|
|
|
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 |
|
|
|
|
|
* 为 ``Callable`` |
|
|
|
|
|
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 |
|
|
|
|
|
的 ``monitor`` 值请返回 ``None`` 。 |
|
|
|
|
|
:param larger_better: 是否是越大越好。 |
|
|
|
|
|
:param log_exception: 是否记录 ``exception`` 。 |
|
|
|
|
|
:param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。 |
|
|
|
|
|
""" |
|
|
|
|
|
assert _module_available('fitlog'), "fitlog is not installed." |
|
|
assert _module_available('fitlog'), "fitlog is not installed." |
|
|
|
|
|
|
|
|
super().__init__(monitor=monitor, larger_better=larger_better) |
|
|
super().__init__(monitor=monitor, larger_better=larger_better) |
|
@@ -34,6 +35,10 @@ class FitlogCallback(HasMonitorCallback): |
|
|
self.log_loss_every = log_loss_every |
|
|
self.log_loss_every = log_loss_every |
|
|
self.avg_loss = 0 |
|
|
self.avg_loss = 0 |
|
|
|
|
|
|
|
|
|
|
|
def on_after_trainer_initialized(self, trainer, driver): |
|
|
|
|
|
if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog |
|
|
|
|
|
fitlog.debug() |
|
|
|
|
|
|
|
|
def on_evaluate_end(self, trainer, results): |
|
|
def on_evaluate_end(self, trainer, results): |
|
|
results = self.itemize_results(results) |
|
|
results = self.itemize_results(results) |
|
|
fitlog.add_metric(results, step=trainer.global_forward_batches, epoch=trainer.cur_epoch_idx) |
|
|
fitlog.add_metric(results, step=trainer.global_forward_batches, epoch=trainer.cur_epoch_idx) |
|
|