
Trainer Update:

* Add a docstring to the Trainer initializer
* Extract the metric-checking logic from _better_eval_result into a _check_eval_results function
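
For context, a minimal sketch of the evaluation-results structures the extracted helper deals with; the metric name and values below are invented for illustration and are not part of this commit:

# Hypothetical shapes of the results passed to _check_eval_results:
# either {metric class name: {indicator: value}} or a (loss, results) tuple.
results = {"AccuracyMetric": {"acc": 0.93}}
results_with_loss = (0.42, {"AccuracyMetric": {"acc": 0.93}})

# With a single indicator the helper returns 0.93 regardless of metric_key;
# with several indicators, metric_key decides which value drives model selection.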
tags/v0.2.0^2
FengZiYjun 6 years ago
parent
commit
d74901e037
1 changed file with 78 additions and 45 deletions:
  1. +78 -45  fastNLP/core/trainer.py

fastNLP/core/trainer.py

@@ -8,20 +8,21 @@ from tensorboardX import SummaryWriter
 from torch import nn
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _check_loss_evaluate
-from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
 
 
 class Trainer(object):
     """Main Training Loop

@@ -33,6 +34,30 @@ class Trainer(object):
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
"""

:param DataSet train_data: the training data
:param torch.nn.modules.module model: a PyTorch model
:param LossBase losser: a loss object
:param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
:param int n_epochs: the number of training epochs
:param int batch_size: batch size for training and validation
:param int print_every: step interval to print next training information. Default: -1(no print).
:param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
:param DataSet dev_data: the validation data
:param use_cuda:
:param str save_path: file path to save models
:param Optimizer optimizer: an optimizer object
:param int check_code_level: level of FastNLP code checker. 0: ignore. 1: warning. 2: strict.
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
smaller, add a `-` character in front of the string. For example
::
metric_key="-PPL" # language model gets better as perplexity gets smaller

:param kwargs:

"""
         super(Trainer, self).__init__()
 
         if not isinstance(train_data, DataSet):
@@ -64,7 +89,7 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)
 
-        if check_code_level>-1:
+        if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         check_level=check_code_level)


@@ -245,52 +270,29 @@ class Trainer(object):


         :return bool value: True means current results on dev set is the best.
         """
-        if isinstance(metrics, tuple):
-            loss, metrics = metrics
-
-        if isinstance(metrics, dict):
-            if len(metrics) == 1:
-                # only single metric, just use it
-                metric_dict = list(metrics.values())[0]
-                metrics_name = list(metrics.keys())[0]
-            else:
-                metrics_name = self.metrics[0].__class__.__name__
-                if metrics_name not in metrics:
-                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-                metric_dict = metrics[metrics_name]
-
-            if len(metric_dict) == 1:
-                indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-            elif len(metric_dict) > 1 and self.metric_key is None:
-                raise RuntimeError(
-                    f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
-            else:
-                # metric_key is set
-                if self.metric_key not in metric_dict:
-                    raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
-                indicator_val = metric_dict[self.metric_key]
-
-        is_better = True
-        if self.best_metric_indicator is None:
-            # first-time validation
-            self.best_metric_indicator = indicator_val
+        indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
         else:
             if self.increase_better is True:
                 if indicator_val > self.best_metric_indicator:
                     self.best_metric_indicator = indicator_val
                 else:
                     is_better = False
             else:
                 if indicator_val < self.best_metric_indicator:
                     self.best_metric_indicator = indicator_val
                 else:
                     is_better = False
         return is_better
 
 
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
 
 
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=0):
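
The refactor does not change how the selected indicator is used. A standalone illustration (not fastNLP code) of the comparison _better_eval_result keeps performing; increase_better is presumably derived from a leading `-` in metric_key, as the docstring above suggests:

# Toy re-statement of the best-model check; names and defaults are illustrative only.
def is_better(indicator_val, best_so_far, increase_better=True):
    if best_so_far is None:                 # first validation always counts as better
        return True
    if increase_better:
        return indicator_val > best_so_far  # e.g. accuracy: larger is better
    return indicator_val < best_so_far      # e.g. perplexity: smaller is better

print(is_better(0.93, 0.90))                          # True
print(is_better(55.0, 60.0, increase_better=False))   # True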
@@ -341,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     # TODO: check whether the returned values are reasonable
 
 
+def _check_eval_results(metrics, metric_key, metric_list):
+    # metrics: the evaluation results returned by the tester
+    # metric_key: a single indicator used for model selection, set in the Trainer initializer
+    # metric_list: the metrics used for evaluation, set in the Trainer initializer
+    if isinstance(metrics, tuple):
+        loss, metrics = metrics
+
+    if isinstance(metrics, dict):
+        if len(metrics) == 1:
+            # only single metric, just use it
+            metric_dict = list(metrics.values())[0]
+            metrics_name = list(metrics.keys())[0]
+        else:
+            metrics_name = metric_list[0].__class__.__name__
+            if metrics_name not in metrics:
+                raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+            metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[metric_key]
+    else:
+        raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
+    return indicator_val
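
Finally, a rough usage sketch of the constructor documented in this commit. train_data, dev_data, model, losser and metric are placeholders assumed to be built elsewhere; only the keyword arguments shown in the docstring above are used:

from fastNLP.core.optimizer import Adam
from fastNLP.core.trainer import Trainer

# Placeholders: train_data/dev_data are DataSet objects, model is a torch.nn.Module,
# losser follows LossBase and metric follows MetricBase; none of them are defined here.
trainer = Trainer(train_data=train_data, model=model, losser=losser,
                  metrics=[metric], dev_data=dev_data,
                  n_epochs=10, batch_size=32,
                  optimizer=Adam(lr=0.01, weight_decay=0),
                  check_code_level=1,      # 0: ignore, 1: warning, 2: strict
                  metric_key="-PPL")       # smaller perplexity counts as better
trainer.train()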
