diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py index ad417340..095ebc3d 100644 --- a/fastNLP/core/callback.py +++ b/fastNLP/core/callback.py @@ -592,9 +592,10 @@ class FitlogCallback(Callback): fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) if better_result: fitlog.add_best_metric(eval_result, name=key) - except Exception: + except Exception as e: self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) - + raise e + def on_train_end(self): fitlog.finish() @@ -660,9 +661,9 @@ class EvaluateCallback(Callback): eval_result = tester.test() self.logger.info("EvaluateCallback evaluation on {}:".format(key)) self.logger.info(tester._format_eval_results(eval_result)) - except Exception: + except Exception as e: self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) - + raise e class LRScheduler(Callback): """ diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py index e92eb422..8e2ac8a7 100644 --- a/fastNLP/core/tester.py +++ b/fastNLP/core/tester.py @@ -189,10 +189,10 @@ class Tester(object): _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, dataset=self.data, check_level=0) - + finally: + self._mode(network, is_test=False) if self.verbose >= 1: logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) - self._mode(network, is_test=False) return eval_results def _mode(self, model, is_test=False):