From d74901e0379ea8cf78dd62c6f2bfaf40dee9facf Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sun, 2 Dec 2018 11:36:35 +0800
Subject: [PATCH] Trainer update: * add an __init__ docstring * extract the
 check-metrics logic from _better_eval_result into a _check_eval_results
 function
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py | 123 +++++++++++++++++++++++++---------------
 1 file changed, 78 insertions(+), 45 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 2a5a59e4..78a26334 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -8,20 +8,21 @@
 from tensorboardX import SummaryWriter
 from torch import nn
 
 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _check_loss_evaluate
-from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
 
+
 class Trainer(object):
     """Main Training Loop
@@ -33,6 +34,30 @@ class Trainer(object):
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
+        """
+
+        :param DataSet train_data: the training data
+        :param torch.nn.Module model: a PyTorch model
+        :param LossBase losser: a loss object
+        :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
+        :param int n_epochs: the number of training epochs
+        :param int batch_size: the batch size used for training and validation
+        :param int print_every: step interval at which training information is printed. Default: -1 (no printing).
+        :param int validate_every: step interval at which validation is run. Default: -1 (validate once per epoch).
+        :param DataSet dev_data: the validation data
+        :param bool use_cuda: whether to run the model on a CUDA (GPU) device
+        :param str save_path: file path to save models
+        :param Optimizer optimizer: an optimizer object
+        :param int check_code_level: level of the fastNLP code checker. 0: ignore. 1: warning. 2: strict.
+        :param str metric_key: a single indicator used to decide the best model based on metric results. It must be
+            one of the keys returned by the FIRST metric in `metrics`. If a smaller value of the indicator means a
+            better overall result, prefix the key with a `-` character. For example
+            ::
+                metric_key="-PPL"   # language model gets better as perplexity gets smaller
+
+        :param kwargs: additional keyword arguments
+
+        """
         super(Trainer, self).__init__()
 
         if not isinstance(train_data, DataSet):
@@ -64,7 +89,7 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)
 
-        if check_code_level>-1:
+        if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics,
                         dev_data=dev_data, check_level=check_code_level)
 
@@ -245,52 +270,29 @@ class Trainer(object):
 
         :return bool value: True means current results on dev set is the best.
         """
-        if isinstance(metrics, tuple):
-            loss, metrics = metrics
-
-        if isinstance(metrics, dict):
-            if len(metrics) == 1:
-                # only single metric, just use it
-                metric_dict = list(metrics.values())[0]
-                metrics_name = list(metrics.keys())[0]
-            else:
-                metrics_name = self.metrics[0].__class__.__name__
-                if metrics_name not in metrics:
-                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-                metric_dict = metrics[metrics_name]
-
-        if len(metric_dict) == 1:
-            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-        elif len(metric_dict) > 1 and self.metric_key is None:
-            raise RuntimeError(
-                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
-        else:
-            # metric_key is set
-            if self.metric_key not in metric_dict:
-                raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
-            indicator_val = metric_dict[self.metric_key]
-
-        is_better = True
-        if self.best_metric_indicator is None:
-            # first-time validation
-            self.best_metric_indicator = indicator_val
-        else:
-            if self.increase_better is True:
-                if indicator_val > self.best_metric_indicator:
-                    self.best_metric_indicator = indicator_val
-                else:
-                    is_better = False
-            else:
-                if indicator_val < self.best_metric_indicator:
-                    self.best_metric_indicator = indicator_val
-                else:
-                    is_better = False
-        return is_better
+        indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics)
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better
 
 
 DEFAULT_CHECK_BATCH_SIZE = 2
 DEFAULT_CHECK_NUM_BATCH = 2
 
+
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=0):
@@ -341,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
 
     # TODO: check whether the returned values are reasonable
 
+
+def _check_eval_results(metrics, metric_key, metric_list):
+    # metrics: the evaluation results returned by the tester
+    # metric_key: a single key used to pick the indicator, from the Trainer's initialization
+    # metric_list: the metrics used for evaluation, from the Trainer's initialization
+    if isinstance(metrics, tuple):
+        loss, metrics = metrics
+
+    if isinstance(metrics, dict):
+        if len(metrics) == 1:
+            # only single metric, just use it
+            metric_dict = list(metrics.values())[0]
+            metrics_name = list(metrics.keys())[0]
+        else:
+            metrics_name = metric_list[0].__class__.__name__
+            if metrics_name not in metrics:
+                raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+            metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[metric_key]
+    else:
+        raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics)))
+    return indicator_val
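
A minimal usage sketch of the extracted _check_eval_results helper. It assumes Tester-style results of the form {metric_class_name: {metric_key: value}}, which is the shape the helper reads via metric_list[0].__class__.__name__; the DummyAccuracy class and the result dicts below are made up for illustration and are not part of fastNLP.

    from fastNLP.core.trainer import _check_eval_results

    class DummyAccuracy:
        """Stands in for a MetricBase subclass; only its class name matters here."""
        pass

    # Single metric with a single key: no metric_key is needed,
    # the lone value becomes the indicator.
    results = {"DummyAccuracy": {"acc": 0.91}}
    print(_check_eval_results(results, metric_key=None, metric_list=[DummyAccuracy()]))  # 0.91

    # Single metric with several keys: metric_key selects the indicator.
    results = {"DummyAccuracy": {"acc": 0.91, "f1": 0.88}}
    print(_check_eval_results(results, metric_key="f1", metric_list=[DummyAccuracy()]))  # 0.88

The leading "-" convention described in the docstring (e.g. metric_key="-PPL") is assumed to be resolved by the Trainer before this helper is called; the helper itself expects the bare key and simply looks it up in the first metric's result dict.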