@@ -33,8 +33,9 @@ class Optimizer(object):
     def construct_from_pytorch(self, model_params):
         raise NotImplementedError
-    def _get_require_grads_param(self, params):
+    @staticmethod
+    def _get_require_grads_param(params):
         """
         Drop the parameters in params that do not require gradients
@@ -43,6 +44,7 @@ class Optimizer(object):
         """
         return [param for param in params if param.requires_grad]
 class NullOptimizer(Optimizer):
     """
     Pass this optimizer in when you do not want the Trainer to run optimizer updates, but make sure the parameters are still updated through callbacks.
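For context, a minimal sketch of what the helper does now that it is a @staticmethod: it can be called on the class without constructing an optimizer. The fastNLP.core.optimizer import path and the frozen bias are assumptions made only for illustration.

import torch
from fastNLP.core.optimizer import Optimizer  # import path assumed for illustration

model = torch.nn.Linear(10, 3)
model.bias.requires_grad = False  # freeze one parameter so the filter has something to drop

# As a @staticmethod the helper needs no Optimizer instance; it simply keeps
# the parameters whose requires_grad flag is still True.
trainable = Optimizer._get_require_grads_param(model.parameters())
print(len(trainable))  # 1 -- only the weight matrix is still trainable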
@@ -113,7 +115,8 @@ class Adam(Optimizer):
 class AdamW(TorchOptimizer):
     r"""
-    An implementation of AdamW that is expected to appear in a later pytorch release, https://github.com/pytorch/pytorch/pull/21250. Added here ahead of time.
+    An implementation of AdamW, which is already available in pytorch 1.2.0, https://github.com/pytorch/pytorch/pull/21250.
+    It is included here for compatibility with older pytorch versions.
     .. todo::
         Translate into Chinese
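A hedged usage sketch of the backported class as a drop-in optimizer; the constructor arguments mirror the test added below, and the concrete lr/weight_decay values are only illustrative.

import torch
from fastNLP import AdamW  # exported at package level, as the test diff below relies on

model = torch.nn.Linear(10, 3)
optim = AdamW(model.parameters(), lr=2e-3, weight_decay=0.01)  # values are illustrative

# One optimization step with decoupled weight decay, as in the referenced pytorch PR.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optim.step()
optim.zero_grad()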
@@ -2,7 +2,7 @@ import unittest
 import torch
-from fastNLP import SGD, Adam
+from fastNLP import SGD, Adam, AdamW
 class TestOptim(unittest.TestCase):
@@ -52,3 +52,12 @@ class TestOptim(unittest.TestCase):
         self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
         res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
         self.assertTrue(isinstance(res, torch.optim.Adam))
+
+    def test_AdamW(self):
+        optim = AdamW(params=torch.nn.Linear(10, 3).parameters())
+        self.assertTrue('lr' in optim.defaults)
+        self.assertTrue('weight_decay' in optim.defaults)
+        optim = AdamW(params=torch.nn.Linear(10, 3).parameters(), lr=0.002, weight_decay=0.989)
+        self.assertEqual(optim.defaults['lr'], 0.002)
+        self.assertEqual(optim.defaults['weight_decay'], 0.989)
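Since the class only exists to cover pytorch releases older than 1.2.0, callers who prefer the upstream implementation whenever it is present could guard the import; this is a sketch outside the diff, not something the patch adds.

import torch

try:
    from torch.optim import AdamW   # available from pytorch 1.2.0 onwards
except ImportError:
    from fastNLP import AdamW       # fall back to the backport added above

optim = AdamW(torch.nn.Linear(10, 3).parameters(), lr=2e-3, weight_decay=0.01)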