
Trainer Update:

* Add comments documenting Trainer initialization
* Extract the metric-checking logic from _better_eval_result into a _check_eval_results function
tags/v0.2.0^2
FengZiYjun, 6 years ago
commit d74901e037
1 changed file with 78 additions and 45 deletions

  1. fastNLP/core/trainer.py (+78, -45)

fastNLP/core/trainer.py

@@ -8,20 +8,21 @@ from tensorboardX import SummaryWriter
 from torch import nn
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _check_loss_evaluate
-from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
 
 
 class Trainer(object):
     """Main Training Loop
 
@@ -33,6 +34,30 @@ class Trainer(object):
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
+        """
+
+        :param DataSet train_data: the training data
+        :param torch.nn.modules.module model: a PyTorch model
+        :param LossBase losser: a loss object
+        :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
+        :param int n_epochs: the number of training epochs
+        :param int batch_size: batch size for training and validation
+        :param int print_every: step interval at which to print training information. Default: -1 (no printing).
+        :param int validate_every: step interval at which to run validation. Default: -1 (validate once per epoch).
+        :param DataSet dev_data: the validation data
+        :param use_cuda:
+        :param str save_path: file path to save models
+        :param Optimizer optimizer: an optimizer object
+        :param int check_code_level: level of the FastNLP code checker. 0: ignore. 1: warning. 2: strict.
+        :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
+            of the keys returned by the FIRST metric in `metrics`. If the overall result gets better as the indicator gets
+            smaller, add a `-` character in front of the string. For example
+            ::
+                metric_key="-PPL"   # language model gets better as perplexity gets smaller
+
+        :param kwargs:
+
+        """
         super(Trainer, self).__init__()
 
         if not isinstance(train_data, DataSet):
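
The docstring above covers the whole constructor surface. As a rough usage sketch, not part of the commit: the dataset, model, loss and metric objects below are placeholders, and the train() call is assumed to be the usual entry point of this class:

    from fastNLP.core.trainer import Trainer

    trainer = Trainer(train_data=train_set, model=my_model, losser=my_loss,
                      metrics=my_metric, dev_data=dev_set,
                      n_epochs=10, batch_size=32,
                      print_every=50, validate_every=-1,   # validate once per epoch
                      metric_key="-PPL",                   # smaller perplexity counts as better
                      check_code_level=1)                  # warn about code issues instead of ignoring them
    trainer.train()
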
@@ -64,7 +89,7 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)
 
-        if check_code_level>-1:
+        if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         check_level=check_code_level)
 
@@ -245,52 +270,29 @@ class Trainer(object):

         :return bool value: True means current results on dev set is the best.
         """
-        if isinstance(metrics, tuple):
-            loss, metrics = metrics
-
-        if isinstance(metrics, dict):
-            if len(metrics) == 1:
-                # only single metric, just use it
-                metric_dict = list(metrics.values())[0]
-                metrics_name = list(metrics.keys())[0]
-            else:
-                metrics_name = self.metrics[0].__class__.__name__
-                if metrics_name not in metrics:
-                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-                metric_dict = metrics[metrics_name]
-
-        if len(metric_dict) == 1:
-            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-        elif len(metric_dict) > 1 and self.metric_key is None:
-            raise RuntimeError(
-                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
-        else:
-            # metric_key is set
-            if self.metric_key not in metric_dict:
-                raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
-            indicator_val = metric_dict[self.metric_key]
-
-        is_better = True
-        if self.best_metric_indicator is None:
-            # first-time validation
-            self.best_metric_indicator = indicator_val
-        else:
-            if self.increase_better is True:
-                if indicator_val > self.best_metric_indicator:
-                    self.best_metric_indicator = indicator_val
-                else:
-                    is_better = False
-            else:
-                if indicator_val < self.best_metric_indicator:
-                    self.best_metric_indicator = indicator_val
-                else:
-                    is_better = False
-        return is_better
+        indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better


 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
 
 
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=0):
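
The comparison in _better_eval_result relies on self.increase_better and self.metric_key, which are set during initialization and do not appear in this diff. A hypothetical sketch, assuming the `-` prefix documented for metric_key is what drives the comparison direction (this helper is illustrative, not code from the repository):

    # Hypothetical helper: split an optional "-" prefix off metric_key and
    # turn it into the "larger is better" flag used by _better_eval_result.
    def _parse_metric_key(metric_key):
        if metric_key is None:
            return None, True                 # default: a larger indicator is better
        if metric_key.startswith("-"):
            return metric_key[1:], False      # e.g. "-PPL": smaller perplexity is better
        return metric_key, True

    # __init__ would then store something like:
    #   self.metric_key, self.increase_better = _parse_metric_key(metric_key)
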
@@ -341,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     # TODO: check whether the returned values are reasonable
 
 
+def _check_eval_results(metrics, metric_key, metric_list):
+    # metrics: the evaluation results returned by Tester
+    # metric_key: a single indicator used for filtering, from the Trainer initialization
+    # metric_list: the metrics used for evaluation, from the Trainer initialization
+    if isinstance(metrics, tuple):
+        loss, metrics = metrics
+
+    if isinstance(metrics, dict):
+        if len(metrics) == 1:
+            # only single metric, just use it
+            metric_dict = list(metrics.values())[0]
+            metrics_name = list(metrics.keys())[0]
+        else:
+            metrics_name = metric_list[0].__class__.__name__
+            if metrics_name not in metrics:
+                raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+            metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[metric_key]
+    else:
+        raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
+    return indicator_val
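
To make the contract of _check_eval_results concrete, here is a hedged example of the input shapes it handles; the metric class name and the numbers are made up for illustration:

    from fastNLP.core.trainer import _check_eval_results

    # Tester-style results: {metric_class_name: {metric_key: value, ...}}
    eval_results = {"AccuracyMetric": {"acc": 0.92, "f1": 0.88}}

    # With several keys in the inner dict, metric_key selects the indicator;
    # metric_list is only consulted when more than one metric class reports results.
    indicator = _check_eval_results(metrics=eval_results, metric_key="acc", metric_list=[])
    assert indicator == 0.92

    # A (loss, results) tuple is accepted as well; the loss part is ignored.
    indicator = _check_eval_results(metrics=(0.31, eval_results), metric_key="acc", metric_list=[])
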
