diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index ed1b0697..cd97b64b 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -29,7 +29,7 @@ The callback module implements many callback classes in fastNLP, used to enhance :class:
     callback.on_valid_end()  # evaluation on other datasets can be run here
     callback.on_epoch_end()  # called at the end of each epoch
     callback.on_train_end()  # training finished
-    callback.on_exception()  # a special step: when an exception is raised during training, control jumps here
+    callback.on_exception()  # a special step: when an exception is raised during training, control jumps here.
 
 As shown in the example below, we can use the built-in callback classes, or subclass :class:`~fastNLP.core.callback.Callback` to define our own callback class::
 
@@ -64,7 +64,7 @@ __all__ = [
 
 import os
 import torch
-
+from copy import deepcopy
 
 try:
     from tensorboardX import SummaryWriter
@@ -73,7 +73,13 @@ except:
     tensorboardX_flag = False
 
 from ..io.model_io import ModelSaver, ModelLoader
+from .dataset import DataSet
+from .tester import Tester
+try:
+    import fitlog
+except:
+    pass
 
 
 class Callback(object):
     """
@@ -425,6 +431,90 @@ class EarlyStopCallback(Callback):
         else:
             raise exception  # re-raise unfamiliar exceptions
 
+
+class FitlogCallback(Callback):
+    """
+    Alias: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback`
+
+    This callback automatically writes the loss and training progress to fitlog. If the Trainer has dev data,
+    the dev results are written to the log as well. It also supports passing in one (or more) test datasets
+    (only usable when the Trainer has dev data); after every evaluation on dev, these datasets are evaluated
+    too and the results are written to fitlog. The results reported for these datasets follow the best dev
+    result, i.e. if the best dev performance is reached at epoch 3, the results recorded in fitlog for these
+    datasets come from epoch 3.
+
+    :param DataSet,dict(DataSet) data: a DataSet object, which will be evaluated with the metrics used by the
+        Trainer. To pass multiple DataSets, use a dict; its keys are passed to fitlog as the names of the
+        corresponding datasets. If tester is not None, data must be passed as a dict. A single DataSet passed
+        directly is named `test`.
+    :param Tester tester: a Tester object, invoked in on_valid_end; its results are recorded under the name `test`.
+    :param int verbose: whether to print results to the terminal; 0 means do not print.
+    :param bool log_exception: whether fitlog records information about exceptions that occur.
+    """
+
+    def __init__(self, data=None, tester=None, verbose=0, log_exception=False):
+        super().__init__()
+        self.datasets = {}
+        self.testers = {}
+        self._log_exception = log_exception
+        if tester is not None:
+            assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
+            assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
+            if data is not None:
+                assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed."
+            setattr(tester, 'verbose', 0)
+            self.testers['test'] = tester
+
+        if isinstance(data, dict):
+            for key, value in data.items():
+                assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
+            for key, value in data.items():
+                self.datasets[key] = value
+        elif isinstance(data, DataSet):
+            self.datasets['test'] = data
+        else:
+            raise TypeError("data receives dict[DataSet] or DataSet object.")
+
+        self.verbose = verbose
+
+    def on_train_begin(self):
+        if (len(self.datasets) > 0 or len(self.testers) > 0) and self.trainer.dev_data is None:
+            raise RuntimeError("Trainer has no dev data, you cannot pass extra data to do evaluation.")
+
+        if len(self.datasets) > 0:
+            for key, data in self.datasets.items():
+                tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics,
+                                verbose=0)
+                self.testers[key] = tester
+        fitlog.add_progress(total_steps=self.n_steps)
+
+    def on_backward_begin(self, loss):
+        fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch)
+
+    def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
+        if better_result:
+            eval_result = deepcopy(eval_result)
+            eval_result['step'] = self.step
+            eval_result['epoch'] = self.epoch
+            fitlog.add_best_metric(eval_result)
+        fitlog.add_metric(eval_result, step=self.step, epoch=self.epoch)
+        if len(self.testers) > 0:
+            for key, tester in self.testers.items():
+                try:
+                    eval_result = tester.test()
+                    if self.verbose != 0:
+                        self.pbar.write("Evaluation on DataSet {}:".format(key))
+                        self.pbar.write(tester._format_eval_results(eval_result))
+                    fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch)
+                    if better_result:
+                        fitlog.add_best_metric(eval_result, name=key)
+                except Exception:
+                    self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
+
+    def on_train_end(self):
+        fitlog.finish()
+
+    def on_exception(self, exception):
+        fitlog.finish(status=1)
+        if self._log_exception:
+            fitlog.add_other(str(exception), name='except_info')
+
 
 class LRScheduler(Callback):
     """
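
For reference, a minimal usage sketch of the new callback (not part of the patch): `train_data`, `dev_data`, `extra_test_data`, and `model` are hypothetical placeholders, and the Trainer keyword arguments assume the fastNLP 0.4.x API::

    import fitlog
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
    from fastNLP.core.callback import FitlogCallback

    fitlog.set_log_dir('logs/')  # fitlog stores its records under this directory

    # `extra_test_data` is a hypothetical extra DataSet, evaluated whenever the dev result improves
    fitlog_callback = FitlogCallback(data={'extra': extra_test_data}, verbose=1, log_exception=True)

    trainer = Trainer(train_data=train_data, model=model,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                      dev_data=dev_data,  # required here: FitlogCallback raises in on_train_begin without dev data
                      callbacks=[fitlog_callback])
    trainer.train()  # fitlog.finish() is called automatically in on_train_end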