Browse Source

增加一个fitlog callback,实现与fitlog实验记录

tags/v0.4.10
yh_cc 5 years ago
parent
commit
6ad85a823b
1 changed files with 92 additions and 2 deletions
  1. +92
    -2
      fastNLP/core/callback.py

+ 92
- 2
fastNLP/core/callback.py View File

@@ -29,7 +29,7 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:
callback.on_valid_end() # 可以进行在其它数据集上进行验证 callback.on_valid_end() # 可以进行在其它数据集上进行验证
callback.on_epoch_end() # epoch结束调用 callback.on_epoch_end() # epoch结束调用
callback.on_train_end() # 训练结束 callback.on_train_end() # 训练结束
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里


如下面的例子所示,我们可以使用内置的 callback 类,或者继承 :class:`~fastNLP.core.callback.Callback` 如下面的例子所示,我们可以使用内置的 callback 类,或者继承 :class:`~fastNLP.core.callback.Callback`
定义自己的 callback 类:: 定义自己的 callback 类::
@@ -64,7 +64,7 @@ __all__ = [
import os import os


import torch import torch
from copy import deepcopy
try: try:
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
@@ -73,7 +73,13 @@ except:
tensorboardX_flag = False tensorboardX_flag = False


from ..io.model_io import ModelSaver, ModelLoader from ..io.model_io import ModelSaver, ModelLoader
from .dataset import DataSet
from .tester import Tester


try:
import fitlog
except:
pass


class Callback(object): class Callback(object):
""" """
@@ -425,6 +431,90 @@ class EarlyStopCallback(Callback):
else: else:
raise exception # 抛出陌生Error raise exception # 抛出陌生Error


class FitlogCallback(Callback):
"""
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`

该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。

:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过
dict的方式传入。如果仅传入DataSet, 则被命名为test
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的会被命名为test
:param int verbose: 是否在终端打印内容,0不打印
:param bool log_exception: fitlog是否记录发生的exception信息
"""

def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
super().__init__()
self.datasets = {}
self.testers = {}
self._log_exception = log_exception
if tester is not None:
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
if data is not None:
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed."
setattr(tester, 'verbose', 0)
self.testers['test'] = tester

if isinstance(data, dict):
for key, value in data.items():
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
for key, value in data.items():
self.datasets[key] = value
elif isinstance(data, DataSet):
self.datasets['test'] = data
else:
raise TypeError("data receives dict[DataSet] or DataSet object.")

self.verbose = verbose

def on_train_begin(self):
if (len(self.datasets)>0 or len(self.testers)>0 ) and self.trainer.dev_data is None:
raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")

if len(self.datasets)>0:
for key, data in self.datasets.items():
tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics,
verbose=0)
self.testers[key] = tester
fitlog.add_progress(total_steps=self.n_steps)

def on_backward_begin(self, loss):
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)

def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
if better_result:
eval_result = deepcopy(eval_result)
eval_result['step'] = self.step
eval_result['epoch'] = self.epoch
fitlog.add_best_metric(eval_result)
fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch)
if len(self.testers)>0:
for key, tester in self.testers.items():
try:
eval_result = tester.test()
if self.verbose!=0:
self.pbar.write("Evaluation on DataSet {}:".format(key))
self.pbar.write(tester._format_eval_results(eval_result))
fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch)
if better_result:
fitlog.add_best_metric(eval_result, name=key)
except Exception:
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))

def on_train_end(self):
fitlog.finish()

def on_exception(self, exception):
fitlog.finish(status=1)
if self._log_exception:
fitlog.add_other(str(exception), name='except_info')



class LRScheduler(Callback): class LRScheduler(Callback):
""" """


Loading…
Cancel
Save