From fcfd8c267ed7f82eb4e1ee8826e8c7fe7567abbb Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 2 May 2022 11:03:19 +0000 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9callback=5Fmanager=E7=9A=84?= =?UTF-8?q?=E6=8A=A5=E9=94=99=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 9 ++++++--- tests/helpers/callbacks/helper_callbacks.py | 4 +--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index f63c6088..2b8fff60 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -11,6 +11,7 @@ from .callback import Callback from fastNLP.core.log import logger from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.envs import rank_zero_call +from fastNLP.core.utils.utils import _get_fun_msg def _transfer(func): @@ -21,10 +22,12 @@ def _transfer(func): def wrapper(manager, *arg, **kwargs): manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1; - returns = [] for callback_fn in manager.callback_fns[func.__name__]: - returns.append(callback_fn(*arg, **kwargs)) - return returns + try: + callback_fn(*arg, **kwargs) + except BaseException as e: + logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") + raise e return wrapper diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py index 4fd5b654..1e0d0e11 100644 --- a/tests/helpers/callbacks/helper_callbacks.py +++ b/tests/helpers/callbacks/helper_callbacks.py @@ -36,12 +36,10 @@ class RecordMetricCallback(Callback): self.larger_better = larger_better self.metric = None self.metric_threshold = metric_threshold - self.metric_begin_value = None + self.metric_begin_value = float('-inf') if larger_better else float('inf') def on_evaluate_end(self, trainer, results): self.metric = results[self.monitor] - if self.metric_begin_value is None: - self.metric_begin_value = self.metric def on_train_end(self, trainer): if self.larger_better: