diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index 7fad2d0b..c65f2a56 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -435,7 +435,7 @@ class FitlogCallback(Callback): """ 别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` - 该callback将loss和progress自动写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 + 该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 @@ -444,15 +444,18 @@ class FitlogCallback(Callback): DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 dict的方式传入。如果仅传入DataSet, 则被命名为test :param Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test` - :param int verbose: 是否在终端打印内容,0不打印 + :param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 + 大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 + :param int verbose: 是否在终端打印evaluation的结果,0不打印。 :param bool log_exception: fitlog是否记录发生的exception信息 """ - def __init__(self, data=None, tester=None, verbose=0, log_exception=False): + def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): super().__init__() self.datasets = {} self.testers = {} self._log_exception = log_exception + assert isinstance(log_loss_every, int) and log_loss_every>=0 if tester is not None: assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." @@ -472,6 +475,8 @@ class FitlogCallback(Callback): raise TypeError("data receives dict[DataSet] or DataSet object.") self.verbose = verbose + self._log_loss_every = log_loss_every + self._avg_loss = 0 def on_train_begin(self): if (len(self.datasets)>0 or len(self.testers)>0 ) and self.trainer.dev_data is None: @@ -485,7 +490,11 @@ class FitlogCallback(Callback): fitlog.add_progress(total_steps=self.n_steps) def on_backward_begin(self, loss): - fitlog.add_loss(loss.item(), name='loss', step=self.step, epoch=self.epoch) + if self._log_loss_every>0: + self._avg_loss += loss.item() + if self.step%self._log_loss_every==0: + fitlog.add_loss(self._avg_loss/self._log_loss_every, name='loss', step=self.step, epoch=self.epoch) + self._avg_loss = 0 def on_valid_end(self, eval_result, metric_key, optimizer, better_result): if better_result: @@ -513,7 +522,7 @@ class FitlogCallback(Callback): def on_exception(self, exception): fitlog.finish(status=1) if self._log_exception: - fitlog.add_other(str(exception), name='except_info') + fitlog.add_other(repr(exception), name='except_info') class LRScheduler(Callback): diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 702cb6e7..3b1a8bf5 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -494,12 +494,14 @@ class Trainer(object): self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) - def train(self, load_best_model=True): + def train(self, load_best_model=True, on_exception='ignore'): """ 使用该函数使Trainer开始训练。 :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 最好的模型参数。 + :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 + 支持'ignore'与'raise': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出。 :return dict: 返回一个字典类型的数据, 内含以下内容:: @@ -527,8 +529,10 @@ class Trainer(object): self.callback_manager.on_train_begin() self._train() self.callback_manager.on_train_end() - except (CallbackException, KeyboardInterrupt) as e: + except (CallbackException, KeyboardInterrupt, Exception) as e: self.callback_manager.on_exception(e) + if on_exception=='raise': + raise e if self.dev_data is not None and hasattr(self, 'best_dev_perf'): print(