diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index b534a72a..4d76c24e 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -33,8 +33,9 @@ class Optimizer(object):
     
     def construct_from_pytorch(self, model_params):
         raise NotImplementedError
-
-    def _get_require_grads_param(self, params):
+
+    @staticmethod
+    def _get_require_grads_param(params):
         """
         将params中不需要gradient的删除
         
@@ -43,6 +44,7 @@ class Optimizer(object):
         """
         return [param for param in params if param.requires_grad]
 
+
 class NullOptimizer(Optimizer):
     """
     当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。
@@ -113,7 +115,8 @@ class Adam(Optimizer):
 
 class AdamW(TorchOptimizer):
     r"""
-    对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入
+    对AdamW的实现,该实现在pytorch 1.2.0版本中已经出现,https://github.com/pytorch/pytorch/pull/21250。
+    这里加入以适配低版本的pytorch
     
     .. todo::
         翻译成中文
diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py
index b9a1c271..2f2487c7 100644
--- a/test/core/test_optimizer.py
+++ b/test/core/test_optimizer.py
@@ -2,7 +2,7 @@ import unittest
 
 import torch
 
-from fastNLP import SGD, Adam
+from fastNLP import SGD, Adam, AdamW
 
 
 class TestOptim(unittest.TestCase):
@@ -52,3 +52,12 @@ class TestOptim(unittest.TestCase):
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
         res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
         self.assertTrue(isinstance(res, torch.optim.Adam))
+
+    def test_AdamW(self):
+        optim = AdamW(params=torch.nn.Linear(10, 3).parameters())
+        self.assertTrue('lr' in optim.defaults)
+        self.assertTrue('weight_decay' in optim.defaults)
+
+        optim = AdamW(params=torch.nn.Linear(10, 3).parameters(), lr=0.002, weight_decay=0.989)
+        self.assertEqual(optim.defaults['lr'], 0.002)
+        self.assertEqual(optim.defaults['weight_decay'], 0.989)
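
For context, the docstring change above notes that AdamW landed in PyTorch 1.2.0 (pytorch/pytorch#21250) and that fastNLP keeps its own copy to support older PyTorch releases. Below is a minimal usage sketch of that shim; it assumes only what the diff shows — AdamW is importable from the fastNLP package root (as in the updated test) and its constructor takes params, lr and weight_decay like torch.optim.AdamW. The toy model and training step are illustrative, not part of this patch.

import torch
from fastNLP import AdamW  # top-level export, as used in the new test case

# Toy model for illustration; any nn.Module's parameters() would do.
model = torch.nn.Linear(10, 3)
optim = AdamW(params=model.parameters(), lr=0.002, weight_decay=0.989)

x = torch.randn(4, 10)
loss = model(x).sum()

optim.zero_grad()
loss.backward()
optim.step()  # applies decoupled weight decay, mirroring torch.optim.AdamW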