|
|
@@ -0,0 +1,49 @@ |
|
|
|
from torch import optim |
|
|
|
|
|
|
|
def get_torch_optimizor(params, alg_name='sgd', **args): |
|
|
|
''' |
|
|
|
construct pytorch optimizor by algorithm's name |
|
|
|
optimizor's argurments can be splicified, for different optimizor's argurments, please see pytorch doc |
|
|
|
|
|
|
|
usage: |
|
|
|
optimizor = get_torch_optimizor(model.parameters(), 'SGD', lr=0.01) |
|
|
|
|
|
|
|
''' |
|
|
|
|
|
|
|
name = alg_name.lower() |
|
|
|
if name == 'adadelta': |
|
|
|
return optim.Adadelta(params, **args) |
|
|
|
elif name == 'adagrad': |
|
|
|
return optim.Adagrad(params, **args) |
|
|
|
elif name == 'adam': |
|
|
|
return optim.Adam(params, **args) |
|
|
|
elif name == 'adamax': |
|
|
|
return optim.Adamax(params, **args) |
|
|
|
elif name == 'asgd': |
|
|
|
return optim.ASGD(params, **args) |
|
|
|
elif name == 'lbfgs': |
|
|
|
return optim.LBFGS(params, **args) |
|
|
|
elif name == 'rmsprop': |
|
|
|
return optim.RMSprop(params, **args) |
|
|
|
elif name == 'rprop': |
|
|
|
return optim.Rprop(params, **args) |
|
|
|
elif name == 'sgd': |
|
|
|
#SGD's parameter lr is required |
|
|
|
if 'lr' not in args: |
|
|
|
args['lr'] = 0.01 |
|
|
|
return optim.SGD(params, **args) |
|
|
|
elif name == 'sparseadam': |
|
|
|
return optim.SparseAdam(params, **args) |
|
|
|
else: |
|
|
|
raise TypeError('no such optimizor named {}'.format(alg_name)) |
|
|
|
|
|
|
|
|
|
|
|
# example usage |
|
|
|
if __name__ == '__main__': |
|
|
|
from torch.nn.modules import Linear |
|
|
|
net = Linear(2, 5) |
|
|
|
|
|
|
|
test1 = get_torch_optimizor(net.parameters(),'adam', lr=1e-2, weight_decay=1e-3) |
|
|
|
print(test1) |
|
|
|
test2 = get_torch_optimizor(net.parameters(), 'SGD') |
|
|
|
print(test2) |