From dbb16d8ba4b64c8162a7a496d85da9c0e9307649 Mon Sep 17 00:00:00 2001 From: yhcc Date: Mon, 10 Oct 2022 16:38:20 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dfitlog=E5=9C=A8raise=20except?= =?UTF-8?q?ion=E7=9A=84=E6=97=B6=E5=80=99=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/__init__.py | 5 ++++- fastNLP/core/callbacks/fitlog_callback.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index feff9f9b..ac34d5ee 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -24,7 +24,9 @@ __all__ = [ "FitlogCallback", - "TimerCallback" + "TimerCallback", + + "TopkSaver" ] @@ -41,3 +43,4 @@ from .more_evaluate_callback import MoreEvaluateCallback from .has_monitor_callback import ResultsMonitor, HasMonitorCallback from .fitlog_callback import FitlogCallback from .timer_callback import TimerCallback +from .topk_saver import TopkSaver diff --git a/fastNLP/core/callbacks/fitlog_callback.py b/fastNLP/core/callbacks/fitlog_callback.py index a7716fa6..44430b67 100644 --- a/fastNLP/core/callbacks/fitlog_callback.py +++ b/fastNLP/core/callbacks/fitlog_callback.py @@ -40,6 +40,7 @@ class FitlogCallback(HasMonitorCallback): self.log_exception = log_exception self.log_loss_every = log_loss_every self.avg_loss = 0 + self.catch_exception = False def on_after_trainer_initialized(self, trainer, driver): if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog @@ -72,9 +73,11 @@ class FitlogCallback(HasMonitorCallback): self.avg_loss = 0 def on_train_end(self, trainer): - fitlog.finish() + if not self.catch_exception: + fitlog.finish() def on_exception(self, trainer, exception): + self.catch_exception = True fitlog.finish(status=1) if self.log_exception: fitlog.add_other(repr(exception), name='except_info')