Browse Source

trainer增加对evaluate结果的check

tags/v0.2.0^2
yh 6 years ago
parent
commit
a05ffd31cd
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      fastNLP/core/trainer.py

+ 4
- 4
fastNLP/core/trainer.py View File

@@ -48,7 +48,7 @@ class Trainer(object):
:param use_cuda: :param use_cuda:
:param str save_path: file path to save models :param str save_path: file path to save models
:param Optimizer optimizer: an optimizer object :param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. 0: ignore. 1: warning. 2: strict.
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
smaller, add a `-` character in front of the string. For example smaller, add a `-` character in front of the string. For example
@@ -91,7 +91,7 @@ class Trainer(object):


if check_code_level > -1: if check_code_level > -1:
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
check_level=check_code_level)
metric_key=metric_key, check_level=check_code_level)


self.train_data = train_data self.train_data = train_data
self.dev_data = dev_data # If None, No validation. self.dev_data = dev_data # If None, No validation.
@@ -294,7 +294,7 @@ DEFAULT_CHECK_NUM_BATCH = 2




def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
dev_data=None,
dev_data=None, metric_key=None,
check_level=0): check_level=0):
# check get_loss 方法 # check get_loss 方法
model_devcie = model.parameters().__next__().device model_devcie = model.parameters().__next__().device
@@ -340,7 +340,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1) batch_size=batch_size, verbose=-1)
evaluate_results = tester.test() evaluate_results = tester.test()
# TODO 这里需要检查是否返回来的值是否是合理的
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)




def _check_eval_results(metrics, metric_key, metric_list): def _check_eval_results(metrics, metric_key, metric_list):


Loading…
Cancel
Save