diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 57a31a69..d7694e00 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -532,7 +532,7 @@ class Trainer(object): self._train() self.callback_manager.on_train_end() - except Exception as e: + except BaseException as e: self.callback_manager.on_exception(e) if on_exception == 'auto': if not isinstance(e, (CallbackException, KeyboardInterrupt)): diff --git a/test/core/test_callbacks.py b/test/core/test_callbacks.py index e2aa5fa4..71a5565d 100644 --- a/test/core/test_callbacks.py +++ b/test/core/test_callbacks.py @@ -66,8 +66,7 @@ class TestCallback(unittest.TestCase): dev_data=data_set, metrics=AccuracyMetric(pred="predict", target="y"), callbacks=[EarlyStopCallback(5)]) - with self.assertRaises(EarlyStopError): - trainer.train() + trainer.train() def test_lr_scheduler(self): data_set, model = prepare_env()