更新Optimizer: 多种初始化方法 1. SGD() 2. SGD(0.01) 3. SGD(lr=0.01) 4. SGD(lr=0.01, momentum=0.9) 5. SGD(model.parameters(), lr=0.1, momentum=0.9)tags/v0.2.0^2
@@ -3,14 +3,41 @@ import torch | |||
class Optimizer(object): | |||
def __init__(self, model_params, **kwargs): | |||
if model_params is not None and not isinstance(model_params, torch.Tensor): | |||
raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params))) | |||
if model_params is not None and not hasattr(model_params, "__next__"): | |||
raise RuntimeError("model parameters should be a generator, rather than {}".format(type(model_params))) | |||
self.model_params = model_params | |||
self.settings = kwargs | |||
class SGD(Optimizer): | |||
def __init__(self, model_params=None, lr=0.001, momentum=0.9): | |||
def __init__(self, *args, **kwargs): | |||
model_params, lr, momentum = None, 0.01, 0.9 | |||
if len(args) == 0 and len(kwargs) == 0: | |||
# SGD() | |||
pass | |||
elif len(args) == 1 and len(kwargs) == 0: | |||
if isinstance(args[0], float) or isinstance(args[0], int): | |||
# SGD(0.001) | |||
lr = args[0] | |||
elif hasattr(args[0], "__next__"): | |||
# SGD(model.parameters()) args[0] is a generator | |||
model_params = args[0] | |||
else: | |||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
elif 2 >= len(kwargs) > 0 and len(args) <= 1: | |||
# SGD(lr=0.01), SGD(lr=0.01, momentum=0.9), SGD(model.parameters(), lr=0.1, momentum=0.9) | |||
if len(args) == 1: | |||
if hasattr(args[0], "__next__"): | |||
model_params = args[0] | |||
else: | |||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
if not all(key in ("lr", "momentum") for key in kwargs): | |||
raise RuntimeError("Invalid SGD arguments. Expect {}, got {}.".format(("lr", "momentum"), kwargs)) | |||
lr = kwargs.get("lr", 0.01) | |||
momentum = kwargs.get("momentum", 0.9) | |||
else: | |||
raise RuntimeError("SGD only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) | |||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | |||
def construct_from_pytorch(self, model_params): | |||
@@ -20,7 +47,30 @@ class SGD(Optimizer): | |||
class Adam(Optimizer): | |||
def __init__(self, model_params=None, lr=0.001, weight_decay=0.8): | |||
def __init__(self, *args, **kwargs): | |||
model_params, lr, weight_decay = None, 0.01, 0.9 | |||
if len(args) == 0 and len(kwargs) == 0: | |||
pass | |||
elif len(args) == 1 and len(kwargs) == 0: | |||
if isinstance(args[0], float) or isinstance(args[0], int): | |||
lr = args[0] | |||
elif hasattr(args[0], "__next__"): | |||
model_params = args[0] | |||
else: | |||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
elif 2 >= len(kwargs) > 0 and len(args) <= 1: | |||
if len(args) == 1: | |||
if hasattr(args[0], "__next__"): | |||
model_params = args[0] | |||
else: | |||
raise RuntimeError("Not supported type {}.".format(type(args[0]))) | |||
if not all(key in ("lr", "weight_decay") for key in kwargs): | |||
raise RuntimeError("Invalid Adam arguments. Expect {}, got {}.".format(("lr", "weight_decay"), kwargs)) | |||
lr = kwargs.get("lr", 0.01) | |||
weight_decay = kwargs.get("weight_decay", 0.9) | |||
else: | |||
raise RuntimeError("Adam only accept 0 or 1 sequential argument, but got {}: {}".format(len(args), args)) | |||
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) | |||
def construct_from_pytorch(self, model_params): | |||
@@ -56,7 +56,10 @@ class Trainer(object): | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
self.increase_better = False if metric_key[0] == "-" else True | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
if metric_key is not None: | |||
self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key | |||
else: | |||
self.metric_key = None | |||
# prepare loss | |||
losser = _prepare_losser(losser) | |||
@@ -144,12 +147,13 @@ class Trainer(object): | |||
del self._summary_writer | |||
def _train_epoch(self, data_iterator, model, epoch, start): | |||
"""Training process in one epoch. | |||
""" | |||
kwargs should contain: | |||
- n_print: int, print training information every n steps. | |||
- start: time.time(), the starting time of this step. | |||
- epoch: int, | |||
:param data_iterator: | |||
:param model: | |||
:param epoch: | |||
:param start: | |||
:return: | |||
""" | |||
for batch_x, batch_y in data_iterator: | |||
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 | |||
@@ -188,7 +192,7 @@ class Trainer(object): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
:param model: a PyTorch model | |||
:param is_test: bool, whether in test mode or not. | |||
:param bool is_test: whether in test mode or not. | |||
""" | |||
if is_test: | |||
@@ -263,7 +267,7 @@ class Trainer(object): | |||
else: | |||
# metric_key is set | |||
if self.metric_key not in metric_dict: | |||
raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}") | |||
raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}") | |||
indicator_val = metric_dict[self.metric_key] | |||
is_better = True | |||
@@ -2,20 +2,43 @@ import unittest | |||
import torch | |||
from fastNLP.core.optimizer import SGD | |||
from fastNLP.core.optimizer import SGD, Adam | |||
class TestOptim(unittest.TestCase): | |||
def test_case(self): | |||
optim = SGD(torch.LongTensor(10)) | |||
print(optim.__dict__) | |||
def test_SGD(self): | |||
optim = SGD(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||
self.assertTrue("momentum" in optim.__dict__["settings"]) | |||
optim_2 = SGD(lr=0.001) | |||
print(optim_2.__dict__) | |||
optim = SGD(0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
optim_2 = SGD(lr=0.002, momentum=0.989) | |||
print(optim_2.__dict__) | |||
optim = SGD(lr=0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
def test_case_2(self): | |||
optim = SGD(lr=0.002, momentum=0.989) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | |||
with self.assertRaises(RuntimeError): | |||
_ = SGD("???") | |||
with self.assertRaises(RuntimeError): | |||
_ = SGD(0.001) | |||
_ = SGD(0.001, lr=0.002) | |||
with self.assertRaises(RuntimeError): | |||
_ = SGD(lr=0.009, shit=9000) | |||
def test_Adam(self): | |||
optim = Adam(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||
optim = Adam(0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
optim = Adam(lr=0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
optim = Adam(lr=0.002, weight_decay=0.989) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) |