|
|
@@ -29,7 +29,7 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: |
|
|
|
callback.on_valid_end() # 可以进行在其它数据集上进行验证 |
|
|
|
callback.on_epoch_end() # epoch结束调用 |
|
|
|
callback.on_train_end() # 训练结束 |
|
|
|
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里 |
|
|
|
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里。 |
|
|
|
|
|
|
|
如下面的例子所示,我们可以使用内置的 callback 类,或者继承 :class:`~fastNLP.core.callback.Callback` |
|
|
|
定义自己的 callback 类:: |
|
|
@@ -64,7 +64,7 @@ __all__ = [ |
|
|
|
import os |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
from copy import deepcopy |
|
|
|
try: |
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
|
|
|
@@ -73,7 +73,13 @@ except: |
|
|
|
tensorboardX_flag = False |
|
|
|
|
|
|
|
from ..io.model_io import ModelSaver, ModelLoader |
|
|
|
from .dataset import DataSet |
|
|
|
from .tester import Tester |
|
|
|
|
|
|
|
try: |
|
|
|
import fitlog |
|
|
|
except: |
|
|
|
pass |
|
|
|
|
|
|
|
class Callback(object): |
|
|
|
""" |
|
|
@@ -425,6 +431,90 @@ class EarlyStopCallback(Callback): |
|
|
|
else: |
|
|
|
raise exception # 抛出陌生Error |
|
|
|
|
|
|
|
class FitlogCallback(Callback): |
|
|
|
""" |
|
|
|
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` |
|
|
|
|
|
|
|
该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 |
|
|
|
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 |
|
|
|
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 |
|
|
|
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 |
|
|
|
|
|
|
|
:param DataSet,dict(DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 |
|
|
|
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 |
|
|
|
dict的方式传入。如果仅传入DataSet, 则被命名为test |
|
|
|
:param Tester tester: Tester对象,将在on_valid_end时调用。tester中的会被命名为test |
|
|
|
:param int verbose: 是否在终端打印内容,0不打印 |
|
|
|
:param bool log_exception: fitlog是否记录发生的exception信息 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, data=None, tester=None, verbose=0, log_exception=False): |
|
|
|
super().__init__() |
|
|
|
self.datasets = {} |
|
|
|
self.testers = {} |
|
|
|
self._log_exception = log_exception |
|
|
|
if tester is not None: |
|
|
|
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." |
|
|
|
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." |
|
|
|
if data is not None: |
|
|
|
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." |
|
|
|
setattr(tester, 'verbose', 0) |
|
|
|
self.testers['test'] = tester |
|
|
|
|
|
|
|
if isinstance(data, dict): |
|
|
|
for key, value in data.items(): |
|
|
|
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." |
|
|
|
for key, value in data.items(): |
|
|
|
self.datasets[key] = value |
|
|
|
elif isinstance(data, DataSet): |
|
|
|
self.datasets['test'] = data |
|
|
|
else: |
|
|
|
raise TypeError("data receives dict[DataSet] or DataSet object.") |
|
|
|
|
|
|
|
self.verbose = verbose |
|
|
|
|
|
|
|
def on_train_begin(self): |
|
|
|
if (len(self.datasets)>0 or len(self.testers)>0 ) and self.trainer.dev_data is None: |
|
|
|
raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.") |
|
|
|
|
|
|
|
if len(self.datasets)>0: |
|
|
|
for key, data in self.datasets.items(): |
|
|
|
tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics, |
|
|
|
verbose=0) |
|
|
|
self.testers[key] = tester |
|
|
|
fitlog.add_progress(total_steps=self.n_steps) |
|
|
|
|
|
|
|
def on_backward_begin(self, loss): |
|
|
|
fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch) |
|
|
|
|
|
|
|
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): |
|
|
|
if better_result: |
|
|
|
eval_result = deepcopy(eval_result) |
|
|
|
eval_result['step'] = self.step |
|
|
|
eval_result['epoch'] = self.epoch |
|
|
|
fitlog.add_best_metric(eval_result) |
|
|
|
fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch) |
|
|
|
if len(self.testers)>0: |
|
|
|
for key, tester in self.testers.items(): |
|
|
|
try: |
|
|
|
eval_result = tester.test() |
|
|
|
if self.verbose!=0: |
|
|
|
self.pbar.write("Evaluation on DataSet {}:".format(key)) |
|
|
|
self.pbar.write(tester._format_eval_results(eval_result)) |
|
|
|
fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) |
|
|
|
if better_result: |
|
|
|
fitlog.add_best_metric(eval_result, name=key) |
|
|
|
except Exception: |
|
|
|
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) |
|
|
|
|
|
|
|
def on_train_end(self): |
|
|
|
fitlog.finish() |
|
|
|
|
|
|
|
def on_exception(self, exception): |
|
|
|
fitlog.finish(status=1) |
|
|
|
if self._log_exception: |
|
|
|
fitlog.add_other(str(exception), name='except_info') |
|
|
|
|
|
|
|
|
|
|
|
class LRScheduler(Callback): |
|
|
|
""" |
|
|
|