|
- import torch
-
-
- class Optimizer(object):
- """
-
- :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
- :param kwargs: additional parameters.
- """
- def __init__(self, model_params, **kwargs):
- if model_params is not None and not hasattr(model_params, "__next__"):
- raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params)))
- self.model_params = model_params
- self.settings = kwargs
-
-
- class SGD(Optimizer):
- """
-
- :param float lr: learning rate. Default: 0.01
- :param float momentum: momentum. Default: 0
- :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
- """
-
- def __init__(self, lr=0.001, momentum=0, model_params=None):
- if not isinstance(lr, float):
- raise TypeError("learning rate has to be float.")
- super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
-
- def construct_from_pytorch(self, model_params):
- if self.model_params is None:
- # careful! generator cannot be assigned.
- return torch.optim.SGD(model_params, **self.settings)
- else:
- return torch.optim.SGD(self.model_params, **self.settings)
-
-
- class Adam(Optimizer):
- """
-
- :param float lr: learning rate
- :param float weight_decay:
- :param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models.
- """
-
- def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None):
- if not isinstance(lr, float):
- raise TypeError("learning rate has to be float.")
- super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad,
- weight_decay=weight_decay)
-
- def construct_from_pytorch(self, model_params):
- if self.model_params is None:
- # careful! generator cannot be assigned.
- return torch.optim.Adam(model_params, **self.settings)
- else:
- return torch.optim.Adam(self.model_params, **self.settings)
|