from torch import optim

'''
Construct PyTorch optimizers (torch.optim) by algorithm name.
'''

def get_torch_optimizer(params, alg_name='sgd', **args):
    """
    Construct a PyTorch optimizer by algorithm name (case-insensitive).

    Optimizer-specific arguments can be passed as keyword arguments; see the
    PyTorch documentation for the arguments each optimizer accepts.

    usage:
        optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01)
    """
    name = alg_name.lower()
    if name == 'adadelta':
        return optim.Adadelta(params, **args)
    elif name == 'adagrad':
        return optim.Adagrad(params, **args)
    elif name == 'adam':
        return optim.Adam(params, **args)
    elif name == 'adamax':
        return optim.Adamax(params, **args)
    elif name == 'asgd':
        return optim.ASGD(params, **args)
    elif name == 'lbfgs':
        return optim.LBFGS(params, **args)
    elif name == 'rmsprop':
        return optim.RMSprop(params, **args)
    elif name == 'rprop':
        return optim.Rprop(params, **args)
    elif name == 'sgd':
        # SGD's lr parameter is required
        if 'lr' not in args:
            args['lr'] = 0.01
        return optim.SGD(params, **args)
    elif name == 'sparseadam':
        return optim.SparseAdam(params, **args)
    else:
        raise TypeError('no such optimizer named {}'.format(alg_name))

if __name__ == '__main__':
    from torch.nn.modules import Linear

    net = Linear(2, 5)

    test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3)
    print(test1)

    test2 = get_torch_optimizer(net.parameters(), 'SGD')
    print(test2)
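
    # A minimal sketch of the error path, assuming 'nadam' is not one of the
    # names handled above: unrecognized names fall through to the else branch
    # and raise TypeError.
    try:
        get_torch_optimizer(net.parameters(), 'nadam')
    except TypeError as err:
        print(err)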
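
    # A minimal single-step sketch using the SGD optimizer built above; the
    # random input/target tensors and the MSE loss are illustrative assumptions,
    # not part of the original module.
    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 2)
    target = torch.randn(4, 5)
    test2.zero_grad()
    loss = F.mse_loss(net(x), target)
    loss.backward()
    test2.step()
    print('loss after one SGD step:', loss.item())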