diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index feff9f9b..ac34d5ee 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -24,7 +24,9 @@ __all__ = [ "FitlogCallback", - "TimerCallback" + "TimerCallback", + + "TopkSaver" ] @@ -41,3 +43,4 @@ from .more_evaluate_callback import MoreEvaluateCallback from .has_monitor_callback import ResultsMonitor, HasMonitorCallback from .fitlog_callback import FitlogCallback from .timer_callback import TimerCallback +from .topk_saver import TopkSaver diff --git a/fastNLP/core/callbacks/fitlog_callback.py b/fastNLP/core/callbacks/fitlog_callback.py index a7716fa6..44430b67 100644 --- a/fastNLP/core/callbacks/fitlog_callback.py +++ b/fastNLP/core/callbacks/fitlog_callback.py @@ -40,6 +40,7 @@ class FitlogCallback(HasMonitorCallback): self.log_exception = log_exception self.log_loss_every = log_loss_every self.avg_loss = 0 + self.catch_exception = False def on_after_trainer_initialized(self, trainer, driver): if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog @@ -72,9 +73,11 @@ class FitlogCallback(HasMonitorCallback): self.avg_loss = 0 def on_train_end(self, trainer): - fitlog.finish() + if not self.catch_exception: + fitlog.finish() def on_exception(self, trainer, exception): + self.catch_exception = True fitlog.finish(status=1) if self.log_exception: fitlog.add_other(repr(exception), name='except_info')