|
@@ -10,13 +10,15 @@ class Optimizer(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SGD(Optimizer): |
|
|
class SGD(Optimizer): |
|
|
def __init__(self, model_params=None, lr=0.01, momentum=0): |
|
|
|
|
|
|
|
|
def __init__(self, lr=0.01, momentum=0, model_params=None): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
:param model_params: a generator. E.g. model.parameters() for PyTorch models. |
|
|
|
|
|
:param float lr: learning rate. Default: 0.01 |
|
|
:param float lr: learning rate. Default: 0.01 |
|
|
:param float momentum: momentum. Default: 0 |
|
|
:param float momentum: momentum. Default: 0 |
|
|
|
|
|
:param model_params: a generator. E.g. model.parameters() for PyTorch models. |
|
|
""" |
|
|
""" |
|
|
|
|
|
if not isinstance(lr, float): |
|
|
|
|
|
raise TypeError("learning rate has to be float.") |
|
|
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) |
|
|
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) |
|
|
|
|
|
|
|
|
def construct_from_pytorch(self, model_params): |
|
|
def construct_from_pytorch(self, model_params): |
|
@@ -28,13 +30,15 @@ class SGD(Optimizer): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Adam(Optimizer): |
|
|
class Adam(Optimizer): |
|
|
def __init__(self, model_params=None, lr=0.01, weight_decay=0): |
|
|
|
|
|
|
|
|
def __init__(self, lr=0.01, weight_decay=0, model_params=None): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
:param model_params: a generator. E.g. model.parameters() for PyTorch models. |
|
|
|
|
|
:param float lr: learning rate |
|
|
:param float lr: learning rate |
|
|
:param float weight_decay: |
|
|
:param float weight_decay: |
|
|
|
|
|
:param model_params: a generator. E.g. model.parameters() for PyTorch models. |
|
|
""" |
|
|
""" |
|
|
|
|
|
if not isinstance(lr, float): |
|
|
|
|
|
raise TypeError("learning rate has to be float.") |
|
|
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) |
|
|
super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) |
|
|
|
|
|
|
|
|
def construct_from_pytorch(self, model_params): |
|
|
def construct_from_pytorch(self, model_params): |
|
|