From fb5215ae733ec50bcb6b71626db9ea7d8486a56a Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sun, 2 Dec 2018 10:58:10 +0800
Subject: [PATCH 1/2] fix bug in Trainer about metric_key; update Optimizer
 with multiple initialization styles: 1. SGD() 2. SGD(0.01) 3. SGD(lr=0.01)
 4. SGD(lr=0.01, momentum=0.9) 5. SGD(model.parameters(), lr=0.1, momentum=0.9)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/optimizer.py   | 58 ++++++++++++++++++++++++++++++++++---
 fastNLP/core/trainer.py     | 20 ++++++++-----
 test/core/test_optimizer.py | 43 ++++++++++++++++++++-------
 3 files changed, 99 insertions(+), 22 deletions(-)

diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index 72737b81..4cb21462 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -3,14 +3,41 @@ import torch

 class Optimizer(object):
     def __init__(self, model_params, **kwargs):
-        if model_params is not None and not isinstance(model_params, torch.Tensor):
-            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        if model_params is not None and not hasattr(model_params, "__next__"):
+            raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs


 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, momentum = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            # SGD()
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                # SGD(0.001)
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                # SGD(model.parameters()); args[0] is a generator
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "momentum") for key in kwargs):
+                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            momentum = kwargs.get("momentum", 0.9)
+        else:
+            raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

     def construct_from_pytorch(self, model_params):
@@ -20,7 +47,30 @@ class SGD(Optimizer):


 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, weight_decay = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "weight_decay") for key in kwargs):
+                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            weight_decay = kwargs.get("weight_decay", 0.9)
+        else:
+            raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

     def construct_from_pytorch(self, model_params):
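The five call shapes listed in the subject line all route through the branches above. A minimal illustration against the patched API (this snippet is not part of the commit; the Linear model is an arbitrary example, and the defaults shown are the ones hard-coded at the top of each __init__):

    import torch
    from fastNLP.core.optimizer import SGD, Adam

    model = torch.nn.Linear(10, 3)

    SGD()                                          # defaults: lr=0.01, momentum=0.9
    SGD(0.01)                                      # a single positional number is taken as lr
    SGD(lr=0.01)                                   # keyword form
    SGD(lr=0.01, momentum=0.9)                     # both hyperparameters by keyword
    SGD(model.parameters(), lr=0.1, momentum=0.9)  # parameter generator first, then keywords

    Adam(lr=0.01, weight_decay=0.05)               # Adam accepts the same call shapes

    # The wrapper only records settings; the Trainer builds the underlying
    # torch.optim object later via construct_from_pytorch(), whose body this
    # patch leaves unchanged.
    opt = SGD(lr=0.1, momentum=0.9)
    torch_opt = opt.construct_from_pytorch(model.parameters())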
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6d31e390..2a5a59e4 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -56,7 +56,10 @@ class Trainer(object):
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
         self.increase_better = False if metric_key[0] == "-" else True
-        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        if metric_key is not None:
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.metric_key = None

         # prepare loss
         losser = _prepare_losser(losser)
@@ -144,12 +147,13 @@ class Trainer(object):
             del self._summary_writer

     def _train_epoch(self, data_iterator, model, epoch, start):
-        """Training process in one epoch.
+        """

-            kwargs should contain:
-                - n_print: int, print training information every n steps.
-                - start: time.time(), the starting time of this step.
-                - epoch: int,
+        :param data_iterator:
+        :param model:
+        :param epoch:
+        :param start:
+        :return:
         """
         for batch_x, batch_y in data_iterator:
             # TODO: possible problem here; if the user changes the device of `prediction` inside the model, this will break
@@ -188,7 +192,7 @@ class Trainer(object):
         """Train mode or Test mode. This is for PyTorch currently.

         :param model: a PyTorch model
-        :param is_test: bool, whether in test mode or not.
+        :param bool is_test: whether in test mode or not.

         """
         if is_test:
@@ -263,7 +267,7 @@ class Trainer(object):
             else:
                 # metric_key is set
                 if self.metric_key not in metric_dict:
-                    raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
+                    raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
                 indicator_val = metric_dict[self.metric_key]

             is_better = True
diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py
index 26e47d43..ab18b9be 100644
--- a/test/core/test_optimizer.py
+++ b/test/core/test_optimizer.py
@@ -2,20 +2,43 @@ import unittest

 import torch

-from fastNLP.core.optimizer import SGD
+from fastNLP.core.optimizer import SGD, Adam


 class TestOptim(unittest.TestCase):
-    def test_case(self):
-        optim = SGD(torch.LongTensor(10))
-        print(optim.__dict__)
+    def test_SGD(self):
+        optim = SGD(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue("lr" in optim.__dict__["settings"])
+        self.assertTrue("momentum" in optim.__dict__["settings"])

-        optim_2 = SGD(lr=0.001)
-        print(optim_2.__dict__)
+        optim = SGD(0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

-        optim_2 = SGD(lr=0.002, momentum=0.989)
-        print(optim_2.__dict__)
+        optim = SGD(lr=0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)

-    def test_case_2(self):
+        optim = SGD(lr=0.002, momentum=0.989)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
+        self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)
+
+        with self.assertRaises(RuntimeError):
+            _ = SGD("???")
         with self.assertRaises(RuntimeError):
-            _ = SGD(0.001)
+            _ = SGD(0.001, lr=0.002)
+        with self.assertRaises(RuntimeError):
+            _ = SGD(lr=0.009, shit=9000)
+
+    def test_Adam(self):
+        optim = Adam(torch.nn.Linear(10, 3).parameters())
+        self.assertTrue("lr" in optim.__dict__["settings"])
+        self.assertTrue("weight_decay" in optim.__dict__["settings"])
+
+        optim = Adam(0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+
+        optim = Adam(lr=0.001)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
+
+        optim = Adam(lr=0.002, weight_decay=0.989)
+        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
+        self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)
From d74901e0379ea8cf78dd62c6f2bfaf40dee9facf Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sun, 2 Dec 2018 11:36:35 +0800
Subject: [PATCH 2/2] Trainer update: * add an initialization docstring
 * extract the metric-checking logic from _better_eval_result into a
 _check_eval_results function
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py | 123 +++++++++++++++++++++++++---------------
 1 file changed, 78 insertions(+), 45 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 2a5a59e4..78a26334 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -8,20 +8,21 @@ from tensorboardX import SummaryWriter
 from torch import nn

 from fastNLP.core.batch import Batch
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _check_loss_evaluate
-from fastNLP.core.utils import _check_forward_error
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature

+
 class Trainer(object):
     """Main Training Loop

@@ -33,6 +34,30 @@ class Trainer(object):
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None,
                  **kwargs):
+        """
+
+        :param DataSet train_data: the training data
+        :param torch.nn.modules.module model: a PyTorch model
+        :param LossBase losser: a loss object
+        :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
+        :param int n_epochs: the number of training epochs
+        :param int batch_size: batch size for training and validation
+        :param int print_every: step interval for printing training information. Default: -1 (no print).
+        :param int validate_every: step interval for running validation. Default: -1 (validate every epoch).
+        :param DataSet dev_data: the validation data
+        :param use_cuda:
+        :param str save_path: file path to save models
+        :param Optimizer optimizer: an optimizer object
+        :param int check_code_level: level of the fastNLP code checker. 0: ignore. 1: warning. 2: strict.
+        :param str metric_key: a single indicator used to decide the best model based on metric results. It must be
+            one of the keys returned by the FIRST metric in `metrics`. If the overall result gets better as the
+            indicator gets smaller, prepend a `-` character to the string. For example
+            ::
+                metric_key="-PPL"   # language model gets better as perplexity gets smaller
+
+        :param kwargs:
+
+        """
         super(Trainer, self).__init__()

         if not isinstance(train_data, DataSet):
@@ -64,7 +89,7 @@ class Trainer(object):
         # prepare loss
         losser = _prepare_losser(losser)

-        if check_code_level>-1:
+        if check_code_level > -1:
             _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics,
                         dev_data=dev_data, check_level=check_code_level)

@@ -245,52 +270,29 @@ class Trainer(object):

         :return bool value: True means current results on dev set is the best.
         """
-        if isinstance(metrics, tuple):
-            loss, metrics = metrics
-
-        if isinstance(metrics, dict):
-            if len(metrics) == 1:
-                # only single metric, just use it
-                metric_dict = list(metrics.values())[0]
-                metrics_name = list(metrics.keys())[0]
-            else:
-                metrics_name = self.metrics[0].__class__.__name__
-                if metrics_name not in metrics:
-                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
-                metric_dict = metrics[metrics_name]
-
-            if len(metric_dict) == 1:
-                indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
-            elif len(metric_dict) > 1 and self.metric_key is None:
-                raise RuntimeError(
-                    f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
Which one to use?") - else: - # metric_key is set - if self.metric_key not in metric_dict: - raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}") - indicator_val = metric_dict[self.metric_key] - - is_better = True - if self.best_metric_indicator is None: - # first-time validation - self.best_metric_indicator = indicator_val + indicator_val = _check_eval_results(metrics, self.metric_key, self.metrics) + is_better = True + if self.best_metric_indicator is None: + # first-time validation + self.best_metric_indicator = indicator_val + else: + if self.increase_better is True: + if indicator_val > self.best_metric_indicator: + self.best_metric_indicator = indicator_val + else: + is_better = False else: - if self.increase_better is True: - if indicator_val > self.best_metric_indicator: - self.best_metric_indicator = indicator_val - else: - is_better = False + if indicator_val < self.best_metric_indicator: + self.best_metric_indicator = indicator_val else: - if indicator_val < self.best_metric_indicator: - self.best_metric_indicator = indicator_val - else: - is_better = False - return is_better + is_better = False + return is_better DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 + def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=0): @@ -341,3 +343,34 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ # TODO 这里需要检查是否返回来的值是否是合理的 +def _check_eval_results(metrics, metric_key, metric_list): + # metrics: tester返回的结果 + # metric_key: 一个用来做筛选的指标,来自Trainer的初始化 + # metric_list: 多个用来做评价的指标,来自Trainer的初始化 + if isinstance(metrics, tuple): + loss, metrics = metrics + + if isinstance(metrics, dict): + if len(metrics) == 1: + # only single metric, just use it + metric_dict = list(metrics.values())[0] + metrics_name = list(metrics.keys())[0] + else: + metrics_name = metric_list[0].__class__.__name__ + if metrics_name not in metrics: + raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") + metric_dict = metrics[metrics_name] + + if len(metric_dict) == 1: + indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] + elif len(metric_dict) > 1 and metric_key is None: + raise RuntimeError( + f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") + else: + # metric_key is set + if metric_key not in metric_dict: + raise RuntimeError(f"metric key {metric_key} not found in {metric_dict}") + indicator_val = metric_dict[metric_key] + else: + raise RuntimeError("Invalid metrics type. Expect {}, got {}".format((tuple, dict), type(metrics))) + return indicator_val