Browse Source

1. 修复Tester测试过程中出现异常model不能重置为Training状态的bug; 2. FitlogCallback, EvaluateCallback在遭遇exception时直接raise,防止出现metric没有重新归零的问题

tags/v0.5.5
yh 5 years ago
parent
commit
abd10e24a8
2 changed files with 7 additions and 6 deletions
  1. +5
    -4
      fastNLP/core/callback.py
  2. +2
    -2
      fastNLP/core/tester.py

+ 5
- 4
fastNLP/core/callback.py View File

@@ -592,9 +592,10 @@ class FitlogCallback(Callback):
fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch) fitlog.add_metric(eval_result, name=key, step=self.step, epoch=self.epoch)
if better_result: if better_result:
fitlog.add_best_metric(eval_result, name=key) fitlog.add_best_metric(eval_result, name=key)
except Exception:
except Exception as e:
self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
raise e

def on_train_end(self): def on_train_end(self):
fitlog.finish() fitlog.finish()
@@ -660,9 +661,9 @@ class EvaluateCallback(Callback):
eval_result = tester.test() eval_result = tester.test()
self.logger.info("EvaluateCallback evaluation on {}:".format(key)) self.logger.info("EvaluateCallback evaluation on {}:".format(key))
self.logger.info(tester._format_eval_results(eval_result)) self.logger.info(tester._format_eval_results(eval_result))
except Exception:
except Exception as e:
self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key)) self.logger.error("Exception happens when evaluate on DataSet named `{}`.".format(key))
raise e


class LRScheduler(Callback): class LRScheduler(Callback):
""" """


+ 2
- 2
fastNLP/core/tester.py View File

@@ -189,10 +189,10 @@ class Tester(object):
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y,
dataset=self.data, check_level=0) dataset=self.data, check_level=0)
finally:
self._mode(network, is_test=False)
if self.verbose >= 1: if self.verbose >= 1:
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results return eval_results
def _mode(self, model, is_test=False): def _mode(self, model, is_test=False):


Loading…
Cancel
Save