|
@@ -11,6 +11,7 @@ from .callback import Callback |
|
|
from fastNLP.core.log import logger |
|
|
from fastNLP.core.log import logger |
|
|
from .progress_callback import ProgressCallback, choose_progress_callback |
|
|
from .progress_callback import ProgressCallback, choose_progress_callback |
|
|
from fastNLP.envs import rank_zero_call |
|
|
from fastNLP.envs import rank_zero_call |
|
|
|
|
|
from fastNLP.core.utils.utils import _get_fun_msg |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _transfer(func): |
|
|
def _transfer(func): |
|
@@ -21,10 +22,12 @@ def _transfer(func): |
|
|
|
|
|
|
|
|
def wrapper(manager, *arg, **kwargs): |
|
|
def wrapper(manager, *arg, **kwargs): |
|
|
manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1; |
|
|
manager.callback_counter[func.__name__] += 1 # 给实际被调用的 callback_fn 的计数加 1; |
|
|
returns = [] |
|
|
|
|
|
for callback_fn in manager.callback_fns[func.__name__]: |
|
|
for callback_fn in manager.callback_fns[func.__name__]: |
|
|
returns.append(callback_fn(*arg, **kwargs)) |
|
|
|
|
|
return returns |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
callback_fn(*arg, **kwargs) |
|
|
|
|
|
except BaseException as e: |
|
|
|
|
|
logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.") |
|
|
|
|
|
raise e |
|
|
return wrapper |
|
|
return wrapper |
|
|
|
|
|
|
|
|
|
|
|
|
|
|