|
|
@@ -2,7 +2,7 @@ import unittest |
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
from fastNLP import SGD, Adam |
|
|
|
from fastNLP import SGD, Adam, AdamW |
|
|
|
|
|
|
|
|
|
|
|
class TestOptim(unittest.TestCase): |
|
|
@@ -52,3 +52,12 @@ class TestOptim(unittest.TestCase): |
|
|
|
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) |
|
|
|
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) |
|
|
|
self.assertTrue(isinstance(res, torch.optim.Adam)) |
|
|
|
|
|
|
|
def test_AdamW(self): |
|
|
|
optim = AdamW(params=torch.nn.Linear(10, 3).parameters()) |
|
|
|
self.assertTrue('lr' in optim.defaults) |
|
|
|
self.assertTrue('weight_decay' in optim.defaults) |
|
|
|
|
|
|
|
optim = AdamW(params=torch.nn.Linear(10, 3).parameters(), lr=0.002, weight_decay=0.989) |
|
|
|
self.assertEqual(optim.defaults['lr'], 0.002) |
|
|
|
self.assertTrue(optim.defaults['weight_decay'], 0.989) |