Browse Source

add test code in AdamW

tags/v0.4.10
Yige Xu 5 years ago
parent
commit
5768cbbfef
2 changed files with 16 additions and 4 deletions
  1. +6
    -3
      fastNLP/core/optimizer.py
  2. +10
    -1
      test/core/test_optimizer.py

+ 6
- 3
fastNLP/core/optimizer.py View File

@@ -33,8 +33,9 @@ class Optimizer(object):
def construct_from_pytorch(self, model_params):
raise NotImplementedError
def _get_require_grads_param(self, params):

@staticmethod
def _get_require_grads_param(params):
"""
将params中不需要gradient的删除
@@ -43,6 +44,7 @@ class Optimizer(object):
"""
return [param for param in params if param.requires_grad]


class NullOptimizer(Optimizer):
"""
当不希望Trainer更新optimizer时,传入本optimizer,但请确保通过callback的方式对参数进行了更新。
@@ -113,7 +115,8 @@ class Adam(Optimizer):

class AdamW(TorchOptimizer):
r"""
对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入
对AdamW的实现,该实现在pytorch 1.2.0版本中已经出现,https://github.com/pytorch/pytorch/pull/21250。
这里加入以适配低版本的pytorch
.. todo::
翻译成中文


+ 10
- 1
test/core/test_optimizer.py View File

@@ -2,7 +2,7 @@ import unittest

import torch

from fastNLP import SGD, Adam
from fastNLP import SGD, Adam, AdamW


class TestOptim(unittest.TestCase):
@@ -52,3 +52,12 @@ class TestOptim(unittest.TestCase):
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
self.assertTrue(isinstance(res, torch.optim.Adam))

def test_AdamW(self):
optim = AdamW(params=torch.nn.Linear(10, 3).parameters())
self.assertTrue('lr' in optim.defaults)
self.assertTrue('weight_decay' in optim.defaults)

optim = AdamW(params=torch.nn.Linear(10, 3).parameters(), lr=0.002, weight_decay=0.989)
self.assertEqual(optim.defaults['lr'], 0.002)
self.assertTrue(optim.defaults['weight_decay'], 0.989)

Loading…
Cancel
Save