From fb5215ae733ec50bcb6b71626db9ea7d8486a56a Mon Sep 17 00:00:00 2001
From: FengZiYjun
Date: Sun, 2 Dec 2018 10:58:10 +0800
Subject: [PATCH] fix bug in Trainer about metric_key. Update Optimizer:
 support multiple initialization styles 1. SGD() 2. SGD(0.01) 3. SGD(lr=0.01)
 4. SGD(lr=0.01, momentum=0.9) 5. SGD(model.parameters(), lr=0.1, momentum=0.9)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/optimizer.py   | 58 ++++++++++++++++++++++++++++++++++---
 fastNLP/core/trainer.py     | 20 ++++++++-----
 test/core/test_optimizer.py | 43 ++++++++++++++++++++-------
 3 files changed, 99 insertions(+), 22 deletions(-)

diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index 72737b81..4cb21462 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -3,14 +3,41 @@ import torch
 
 class Optimizer(object):
     def __init__(self, model_params, **kwargs):
-        if model_params is not None and not isinstance(model_params, torch.Tensor):
-            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        if model_params is not None and not hasattr(model_params, "__next__"):
+            raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params)))
         self.model_params = model_params
         self.settings = kwargs
 
 
 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, momentum = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            # SGD()
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                # SGD(0.001)
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                # SGD(model.parameters()), args[0] is a generator
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            # SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9)
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "momentum") for key in kwargs):
+                raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            momentum = kwargs.get("momentum", 0.9)
+        else:
+            raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
@@ -20,7 +47,30 @@ class SGD(Optimizer):
 
 
 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+    def __init__(self, *args, **kwargs):
+        model_params, lr, weight_decay = None, 0.01, 0.9
+        if len(args) == 0 and len(kwargs) == 0:
+            pass
+        elif len(args) == 1 and len(kwargs) == 0:
+            if isinstance(args[0], float) or isinstance(args[0], int):
+                lr = args[0]
+            elif hasattr(args[0], "__next__"):
+                model_params = args[0]
+            else:
+                raise RuntimeError("Not supported type {}.".format(type(args[0])))
+        elif 2 >= len(kwargs) > 0 and len(args) <= 1:
+            if len(args) == 1:
+                if hasattr(args[0], "__next__"):
+                    model_params = args[0]
+                else:
+                    raise RuntimeError("Not supported type {}.".format(type(args[0])))
+            if not all(key in ("lr", "weight_decay") for key in kwargs):
+                raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs))
+            lr = kwargs.get("lr", 0.01)
+            weight_decay = kwargs.get("weight_decay", 0.9)
+        else:
+            raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args))
+
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
 
     def construct_from_pytorch(self, model_params):
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6d31e390..2a5a59e4 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -56,7 +56,10 @@ class Trainer(object):
         # increase_better is True. It means the exp result gets better if the indicator increases.
         # It is true by default.
         self.increase_better = False if metric_key[0] == "-" else True
-        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        if metric_key is not None:
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.metric_key = None
 
         # prepare loss
         losser = _prepare_losser(losser)
@@ -144,12 +147,13 @@ class Trainer(object):
         del self._summary_writer
 
     def _train_epoch(self, data_iterator, model, epoch, start):
-        """Training process in one epoch.
+        """
 
-        kwargs should contain:
-            - n_print: int, print training information every n steps.
-            - start: time.time(), the starting time of this step.
-            - epoch: int,
+        :param data_iterator:
+        :param model:
+        :param epoch:
+        :param start:
+        :return:
         """
         for batch_x, batch_y in data_iterator:
             # TODO There may be a problem here: if the user changes the device of the prediction inside the model, this will break.
@@ -188,7 +192,7 @@ class Trainer(object):
         """Train mode or Test mode. This is for PyTorch currently.
 
         :param model: a PyTorch model
-        :param is_test: bool, whether in test mode or not.
+        :param bool is_test: whether in test mode or not.
 
""" if is_test: @@ -263,7 +267,7 @@ class Trainer(object): else: # metric_key is set if self.metric_key not in metric_dict: - raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}") + raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}") indicator_val = metric_dict[self.metric_key] is_better = True diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py index 26e47d43..ab18b9be 100644 --- a/test/core/test_optimizer.py +++ b/test/core/test_optimizer.py @@ -2,20 +2,43 @@ import unittest import torch -from fastNLP.core.optimizer import SGD +from fastNLP.core.optimizer import SGD, Adam class TestOptim(unittest.TestCase): - def test_case(self): - optim = SGD(torch.LongTensor(10)) - print(optim.__dict__) + def test_SGD(self): + optim = SGD(torch.nn.Linear(10, 3).parameters()) + self.assertTrue("lr" in optim.__dict__["settings"]) + self.assertTrue("momentum" in optim.__dict__["settings"]) - optim_2 = SGD(lr=0.001) - print(optim_2.__dict__) + optim = SGD(0.001) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) - optim_2 = SGD(lr=0.002, momentum=0.989) - print(optim_2.__dict__) + optim = SGD(lr=0.001) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) - def test_case_2(self): + optim = SGD(lr=0.002, momentum=0.989) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) + self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) + + with self.assertRaises(RuntimeError): + _ = SGD("???") with self.assertRaises(RuntimeError): - _ = SGD(0.001) + _ = SGD(0.001, lr=0.002) + with self.assertRaises(RuntimeError): + _ = SGD(lr=0.009, shit=9000) + + def test_Adam(self): + optim = Adam(torch.nn.Linear(10, 3).parameters()) + self.assertTrue("lr" in optim.__dict__["settings"]) + self.assertTrue("weight_decay" in optim.__dict__["settings"]) + + optim = Adam(0.001) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) + + optim = Adam(lr=0.001) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) + + optim = Adam(lr=0.002, weight_decay=0.989) + self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) + self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)