Browse Source

修改callback_manager的报错信息

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
fcfd8c267e
2 changed files with 7 additions and 6 deletions
  1. +6
    -3
      fastNLP/core/callbacks/callback_manager.py
  2. +1
    -3
      tests/helpers/callbacks/helper_callbacks.py

+ 6
- 3
fastNLP/core/callbacks/callback_manager.py View File

@@ -11,6 +11,7 @@ from .callback import Callback
from fastNLP.core.log import logger
from .progress_callback import ProgressCallback, choose_progress_callback
from fastNLP.envs import rank_zero_call
from fastNLP.core.utils.utils import _get_fun_msg


def _transfer(func):
@@ -21,10 +22,12 @@ def _transfer(func):

def wrapper(manager, *arg, **kwargs):
manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1;
returns = []
for callback_fn in manager.callback_fns[func.__name__]:
returns.append(callback_fn(*arg, **kwargs))
return returns
try:
callback_fn(*arg, **kwargs)
except BaseException as e:
logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
raise e
return wrapper




+ 1
- 3
tests/helpers/callbacks/helper_callbacks.py View File

@@ -36,12 +36,10 @@ class RecordMetricCallback(Callback):
self.larger_better = larger_better
self.metric = None
self.metric_threshold = metric_threshold
self.metric_begin_value = None
self.metric_begin_value = float('-inf') if larger_better else float('inf')

def on_evaluate_end(self, trainer, results):
self.metric = results[self.monitor]
if self.metric_begin_value is None:
self.metric_begin_value = self.metric

def on_train_end(self, trainer):
if self.larger_better:


Loading…
Cancel
Save