diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py
index 3f3691e3..7252398c 100644
--- a/fastNLP/core/callbacks/callback_events.py
+++ b/fastNLP/core/callbacks/callback_events.py
@@ -4,8 +4,6 @@
 from types import DynamicClassAttribute
 from functools import wraps
 
-import fastNLP
-
 __all__ = [
     'Events',
     'EventsList',
diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py
index f63c6088..2b8fff60 100644
--- a/fastNLP/core/callbacks/callback_manager.py
+++ b/fastNLP/core/callbacks/callback_manager.py
@@ -11,6 +11,7 @@
 from .callback import Callback
 from fastNLP.core.log import logger
 from .progress_callback import ProgressCallback, choose_progress_callback
 from fastNLP.envs import rank_zero_call
+from fastNLP.core.utils.utils import _get_fun_msg
 
 def _transfer(func):
@@ -21,10 +22,12 @@ def _transfer(func):
 
 
     def wrapper(manager, *arg, **kwargs):
         manager.callback_counter[func.__name__] += 1  # add 1 to the count of the callback_fn that is actually called;
-        returns = []
         for callback_fn in manager.callback_fns[func.__name__]:
-            returns.append(callback_fn(*arg, **kwargs))
-        return returns
+            try:
+                callback_fn(*arg, **kwargs)
+            except BaseException as e:
+                logger.error(f"The following callback_fn raised an exception: {_get_fun_msg(callback_fn)}.")
+                raise e
 
     return wrapper
diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py
index 46feafa5..543c0c57 100644
--- a/tests/core/controllers/test_trainer_paddle.py
+++ b/tests/core/controllers/test_trainer_paddle.py
@@ -11,7 +11,6 @@
 from paddle.io import DataLoader
 from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1
 from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset
-from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback
 from tests.helpers.utils import magic_argv_env_context
 
 @dataclass
diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py
index d8dd7d73..891626b5 100644
--- a/tests/core/controllers/test_trainer_w_evaluator_torch.py
+++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py
@@ -100,17 +100,16 @@ def model_and_optimizers(request):
 # Test the ordinary case;
 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])])  # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1])
-@pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]])
 @pytest.mark.parametrize("evaluate_every", [-3, -1, 100])
 @magic_argv_env_context
 def test_trainer_torch_with_evaluator(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
-        callbacks,
         evaluate_every,
         n_epochs=10,
 ):
+    callbacks = [RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -172,7 +171,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
     if dist.is_initialized():
         dist.destroy_process_group()
 
-
+@pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_validate_every(
@@ -184,9 +183,7 @@
 
     def validate_every(trainer):
         if trainer.global_forward_batches % 10 == 0:
-            print(trainer)
             print("\nfastNLP test validate every.\n")
-            print(trainer.global_forward_batches)
             return True
 
     trainer = Trainer(
diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py
index 4fd5b654..1e0d0e11 100644
--- a/tests/helpers/callbacks/helper_callbacks.py
+++ b/tests/helpers/callbacks/helper_callbacks.py
@@ -36,12 +36,10 @@ class RecordMetricCallback(Callback):
         self.larger_better = larger_better
         self.metric = None
         self.metric_threshold = metric_threshold
-        self.metric_begin_value = None
+        self.metric_begin_value = float('-inf') if larger_better else float('inf')
 
     def on_evaluate_end(self, trainer, results):
         self.metric = results[self.monitor]
-        if self.metric_begin_value is None:
-            self.metric_begin_value = self.metric
 
     def on_train_end(self, trainer):
         if self.larger_better:
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index 7e02ca0d..463f144d 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -30,12 +30,12 @@ def recover_logger(fn):
     return wrapper
 
 
-def magic_argv_env_context(fn=None, timeout=600):
+def magic_argv_env_context(fn=None, timeout=300):
     """
     Wraps each individual test function during testing so that ddp tests run correctly;
     the extra argv arguments added by pytest are dropped.
 
-    :param timeout: if a test has not passed within this long, it is actively killed; defaults to 10 minutes, in seconds;
+    :param timeout: if a test has not passed within this long, it is actively killed; defaults to 5 minutes, in seconds;
     :return:
     """
     # this means it was invoked via @magic_argv_env_context(timeout=600);
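
For context on the behavioral change in `_transfer` above: callback dispatch is now fail-fast, naming the failing callback in the log before the exception propagates, instead of collecting (unused) return values and letting an error escape anonymously. Below is a minimal, self-contained sketch of that pattern; the `DummyManager` class and the plain `logging` logger are simplified stand-ins for illustration, not fastNLP's actual `CallbackManager` or `_get_fun_msg`.

import logging
from collections import defaultdict

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("callback_demo")

def transfer(func):
    # Forward a manager-level event to every callback registered under the same name.
    def wrapper(manager, *args, **kwargs):
        for callback_fn in manager.callback_fns[func.__name__]:
            try:
                callback_fn(*args, **kwargs)
            except BaseException:
                # Name the offending callback before re-raising, so a crash inside
                # one of several callbacks is easy to attribute.
                logger.error(f"The following callback_fn raised an exception: {callback_fn!r}.")
                raise
    return wrapper

class DummyManager:
    def __init__(self):
        self.callback_fns = defaultdict(list)

    @transfer
    def on_train_begin(self):
        pass

def bad_callback():
    raise ValueError("boom")

manager = DummyManager()
manager.callback_fns["on_train_begin"].append(bad_callback)
manager.on_train_begin()  # logs the failing callback, then re-raises ValueError

Because the wrapper re-raises rather than swallowing the exception, a run still aborts on a broken callback; the only change is that the log now pinpoints which one broke.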