
add optimizer constructor function

choosewhatulike committed 6 years ago
commit 875fdc46a5
tags/v0.1.0
1 changed file with 49 additions and 0 deletions:
fastNLP/action/optimizor.py (+49, -0)

@@ -0,0 +1,49 @@
from torch import optim

def get_torch_optimizor(params, alg_name='sgd', **args):
    '''
    Construct a PyTorch optimizer by algorithm name.
    Optimizer-specific arguments can be passed as keyword arguments; see the PyTorch docs
    for the arguments accepted by each optimizer.

    usage:
        optimizor = get_torch_optimizor(model.parameters(), 'SGD', lr=0.01)
    '''
    name = alg_name.lower()
    if name == 'adadelta':
        return optim.Adadelta(params, **args)
    elif name == 'adagrad':
        return optim.Adagrad(params, **args)
    elif name == 'adam':
        return optim.Adam(params, **args)
    elif name == 'adamax':
        return optim.Adamax(params, **args)
    elif name == 'asgd':
        return optim.ASGD(params, **args)
    elif name == 'lbfgs':
        return optim.LBFGS(params, **args)
    elif name == 'rmsprop':
        return optim.RMSprop(params, **args)
    elif name == 'rprop':
        return optim.Rprop(params, **args)
    elif name == 'sgd':
        # SGD requires lr, so supply a default if none is given
        if 'lr' not in args:
            args['lr'] = 0.01
        return optim.SGD(params, **args)
    elif name == 'sparseadam':
        return optim.SparseAdam(params, **args)
    else:
        raise TypeError('no such optimizer named {}'.format(alg_name))


# example usage
if __name__ == '__main__':
    from torch.nn.modules import Linear
    net = Linear(2, 5)

    test1 = get_torch_optimizor(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
    print(test1)
    test2 = get_torch_optimizor(net.parameters(), 'SGD')
    print(test2)
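
For context (not part of this commit), below is a minimal sketch of how an optimizer returned by get_torch_optimizor might be used in a standard PyTorch training step. The toy model, random data, and loss function here are assumed placeholders, not anything defined by fastNLP.

# Sketch only (assumed setup): a toy regression model trained with an
# optimizer built by get_torch_optimizor.
import torch
from torch import nn

model = nn.Linear(2, 1)
optimizer = get_torch_optimizor(model.parameters(), 'adam', lr=1e-3)
loss_fn = nn.MSELoss()

x = torch.randn(16, 2)   # random inputs
y = torch.randn(16, 1)   # random targets

for step in range(10):
    optimizer.zero_grad()        # clear gradients from the previous step
    loss = loss_fn(model(x), y)  # forward pass and loss
    loss.backward()              # backpropagate
    optimizer.step()             # update parameters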
