Browse Source

fix optimizer

tags/v0.1.0
choosewhatulike 6 years ago
parent
commit
2e96f7308c
1 changed files with 4 additions and 49 deletions
  1. +4
    -49
      fastNLP/action/optimizer.py

+ 4
- 49
fastNLP/action/optimizer.py View File

@@ -1,50 +1,5 @@
from torch import optim
'''
use optimizer from Pytorch
'''



def get_torch_optimizer(params, alg_name='sgd', **args):
"""
construct PyTorch optimizer by algorithm's name
optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc

usage:
optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)

"""

name = alg_name.lower()
if name == 'adadelta':
return optim.Adadelta(params, **args)
elif name == 'adagrad':
return optim.Adagrad(params, **args)
elif name == 'adam':
return optim.Adam(params, **args)
elif name == 'adamax':
return optim.Adamax(params, **args)
elif name == 'asgd':
return optim.ASGD(params, **args)
elif name == 'lbfgs':
return optim.LBFGS(params, **args)
elif name == 'rmsprop':
return optim.RMSprop(params, **args)
elif name == 'rprop':
return optim.Rprop(params, **args)
elif name == 'sgd':
# SGD's parameter lr is required
if 'lr' not in args:
args['lr'] = 0.01
return optim.SGD(params, **args)
elif name == 'sparseadam':
return optim.SparseAdam(params, **args)
else:
raise TypeError('no such optimizer named {}'.format(alg_name))


if __name__ == '__main__':
from torch.nn.modules import Linear

net = Linear(2, 5)

test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
print(test1)
test2 = get_torch_optimizer(net.parameters(), 'SGD')
print(test2)
from torch.optim import *

Loading…
Cancel
Save