From a05ffd31cd07f5ebce511260ec086d406c47d332 Mon Sep 17 00:00:00 2001
From: yh
Date: Sun, 2 Dec 2018 12:55:15 +0800
Subject: [PATCH] trainer: add a check on the evaluate results
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 78a26334..2c57057f 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -48,7 +48,7 @@ class Trainer(object):
         :param use_cuda:
         :param str save_path: file path to save models
         :param Optimizer optimizer: an optimizer object
-        :param int check_code_level: level of FastNLP code checker. 0: ignore. 1: warning. 2: strict.
+        :param int check_code_level: level of FastNLP code checker. -1: don't check. 0: ignore. 1: warning. 2: strict.
         :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
             of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
             smaller, add a `-` character in front of the string. For example
@@ -91,7 +91,7 @@ class Trainer(object):
 
         if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
-                        check_level=check_code_level)
+                        metric_key=metric_key, check_level=check_code_level)
 
         self.train_data = train_data
         self.dev_data = dev_data  # If None, No validation.
@@ -294,7 +294,7 @@ DEFAULT_CHECK_NUM_BATCH = 2
 
 
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
-                dev_data=None,
+                dev_data=None, metric_key=None,
                 check_level=0):
     # check the get_loss method
     model_devcie = model.parameters().__next__().device
@@ -340,7 +340,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                     batch_size=batch_size, verbose=-1)
     evaluate_results = tester.test()
-    # TODO: check here whether the returned values are reasonable
+    _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
 
 
 def _check_eval_results(metrics, metric_key, metric_list):
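
The body of _check_eval_results lies outside this hunk. For reference, here is a minimal sketch (not the actual fastNLP implementation) of the kind of validation it could perform, assuming tester.test() returns a nested dict of the form {metric_name: {result_key: value}} and following the docstring's convention that metric_key must be one of the keys returned by the first metric, optionally prefixed with '-' when smaller is better:

def _check_eval_results(metrics, metric_key, metric_list):
    # metrics: the value returned by tester.test() (shape is an assumption)
    # metric_key: the user-chosen indicator, possibly None or '-'-prefixed
    # metric_list: the metric objects that were passed to the Trainer
    if not isinstance(metrics, dict):
        raise TypeError("evaluate results should be a dict, got {}.".format(type(metrics)))
    for name, result in metrics.items():
        if not isinstance(result, dict):
            raise TypeError("results of metric `{}` should be a dict, got {}.".format(name, type(result)))
    if metric_key is not None:
        # a leading '-' marks "smaller is better"; strip it before the lookup
        key = metric_key[1:] if metric_key.startswith('-') else metric_key
        first_result = list(metrics.values())[0]
        if key not in first_result:
            raise RuntimeError("metric key `{}` not found in the results of the first metric: {}."
                               .format(key, list(first_result.keys())))

Whatever the exact body, running this check inside _check_code means an invalid metric_key or a malformed metric result surfaces on the small DEFAULT_CHECK_NUM_BATCH probe run, before full training starts, which is the point of passing metric_key through to _check_code in this patch.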