From 2e96f7308ccb8a66bd53902e10af16b92ff0f6dc Mon Sep 17 00:00:00 2001 From: choosewhatulike <1901722105@qq.com> Date: Thu, 26 Jul 2018 22:38:54 +0800 Subject: [PATCH] fix optimizer --- fastNLP/action/optimizer.py | 53 +++---------------------------------- 1 file changed, 4 insertions(+), 49 deletions(-) diff --git a/fastNLP/action/optimizer.py b/fastNLP/action/optimizer.py index becdc499..b493e3f0 100644 --- a/fastNLP/action/optimizer.py +++ b/fastNLP/action/optimizer.py @@ -1,50 +1,5 @@ -from torch import optim +''' +use optimizer from Pytorch +''' - -def get_torch_optimizer(params, alg_name='sgd', **args): - """ - construct PyTorch optimizer by algorithm's name - optimizer's arguments can be specified, for different optimizer's arguments, please see PyTorch doc - - usage: - optimizer = get_torch_optimizer(model.parameters(), 'SGD', lr=0.01) - - """ - - name = alg_name.lower() - if name == 'adadelta': - return optim.Adadelta(params, **args) - elif name == 'adagrad': - return optim.Adagrad(params, **args) - elif name == 'adam': - return optim.Adam(params, **args) - elif name == 'adamax': - return optim.Adamax(params, **args) - elif name == 'asgd': - return optim.ASGD(params, **args) - elif name == 'lbfgs': - return optim.LBFGS(params, **args) - elif name == 'rmsprop': - return optim.RMSprop(params, **args) - elif name == 'rprop': - return optim.Rprop(params, **args) - elif name == 'sgd': - # SGD's parameter lr is required - if 'lr' not in args: - args['lr'] = 0.01 - return optim.SGD(params, **args) - elif name == 'sparseadam': - return optim.SparseAdam(params, **args) - else: - raise TypeError('no such optimizer named {}'.format(alg_name)) - - -if __name__ == '__main__': - from torch.nn.modules import Linear - - net = Linear(2, 5) - - test1 = get_torch_optimizer(net.parameters(), 'adam', lr=1e-2, weight_decay=1e-3) - print(test1) - test2 = get_torch_optimizer(net.parameters(), 'SGD') - print(test2) +from torch.optim import * \ No newline at end of file