
Adjust parameter order in optimizer initialization

tags/v0.2.0^2
yh · 6 years ago · commit 72877c6ed5
2 changed files with 12 additions and 7 deletions
  1. fastNLP/core/__init__.py   +4 -3
  2. fastNLP/core/optimizer.py  +8 -4

fastNLP/core/__init__.py  (+4 -3)

@@ -2,9 +2,10 @@ from .batch import Batch
 from .dataset import DataSet
 from .fieldarray import FieldArray
 from .instance import Instance
-from .losses import LossFromTorch
-from .optimizer import Optimizer
+from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
+from .metrics import AccuracyMetric
+from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
 from .tester import Tester
 from .trainer import Trainer
-from .vocabulary import Vocabulary
+from .vocabulary import Vocabulary
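
With the widened exports above, the concrete optimizers, the new loss classes, and AccuracyMetric become importable straight from fastNLP.core. A minimal sketch of the new import surface (assuming fastNLP is installed at this revision):

    from fastNLP.core import Optimizer, SGD, Adam    # SGD and Adam now re-exported
    from fastNLP.core import AccuracyMetric          # metric newly exposed here
    from fastNLP.core import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward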

fastNLP/core/optimizer.py  (+8 -4)

@@ -10,13 +10,15 @@ class Optimizer(object):
 
 
 class SGD(Optimizer):
-    def __init__(self, model_params=None, lr=0.01, momentum=0):
+    def __init__(self, lr=0.01, momentum=0, model_params=None):
         """
 
-        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         :param float lr: learning rate. Default: 0.01
         :param float momentum: momentum. Default: 0
+        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         """
+        if not isinstance(lr, float):
+            raise TypeError("learning rate has to be float.")
         super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)
 
     def construct_from_pytorch(self, model_params):
@@ -28,13 +30,15 @@ class SGD(Optimizer):
 
 
 class Adam(Optimizer):
-    def __init__(self, model_params=None, lr=0.01, weight_decay=0):
+    def __init__(self, lr=0.01, weight_decay=0, model_params=None):
         """
 
-        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         :param float lr: learning rate
        :param float weight_decay:
+        :param model_params: a generator. E.g. model.parameters() for PyTorch models.
         """
+        if not isinstance(lr, float):
+            raise TypeError("learning rate has to be float.")
         super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)
 
     def construct_from_pytorch(self, model_params):
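
After this change, lr is the first positional argument and model_params moves to the optional last slot, while the new isinstance check rejects non-float learning rates up front. A minimal usage sketch under those assumptions (the Linear model is hypothetical, used only to have parameters to pass):

    import torch
    from fastNLP.core import SGD, Adam

    model = torch.nn.Linear(10, 2)   # hypothetical model, for illustration only

    # lr now comes first, so positional construction reads lr-then-momentum;
    # model_params may still be supplied explicitly as the last keyword.
    sgd = SGD(0.01, 0.9)                         # lr=0.01, momentum=0.9, params bound later
    adam = Adam(lr=0.001, weight_decay=1e-4,
                model_params=model.parameters())

    # The added type check fails fast on a non-float learning rate:
    # SGD(lr=1) raises TypeError("learning rate has to be float.")

Note that callers who previously passed model.parameters() as the first positional argument would now have it interpreted as lr; the float check turns that silent misuse into an immediate TypeError.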

